Source code for stellascript.audio.enhancement

# stellascript/audio/enhancement.py

"""
Handles audio enhancement using various methods like DeepFilterNet and Demucs.
"""

import warnings
from typing import Any, Optional

import numpy as np
import torch
import torchaudio

from ..logging_config import get_logger

logger = get_logger(__name__)


class AudioEnhancer:
    """
    Apply an optional enhancement pass (denoising / vocal isolation) to audio.

    Supported methods are ``'none'`` (pass-through), ``'deepfilternet'`` and
    ``'demucs'``. Models are loaded lazily on the first call to :meth:`apply`
    and cached on the instance for subsequent calls.
    """

    # Sample rate the audio is converted to before DeepFilterNet is applied.
    _DEEPFILTERNET_RATE: int = 48000
    # Sample rate the audio is converted to before Demucs is applied.
    _DEMUCS_RATE: int = 44100
    # Position of the stem kept from the Demucs output.
    # NOTE(review): index 3 is presumably the "vocals" stem of 'htdemucs' -
    # confirm against the model's source ordering.
    _DEMUCS_VOCALS_INDEX: int = 3

    def __init__(self, enhancement_method: str, device: torch.device, rate: int) -> None:
        """
        Initializes the AudioEnhancer.

        Args:
            enhancement_method (str): The enhancement method to use
                ('none', 'deepfilternet', 'demucs').
            device (torch.device): The device to run the models on (CPU or CUDA).
            rate (int): The sample rate of the input audio.
        """
        self.enhancement_method: str = enhancement_method
        self.device: torch.device = device
        self.rate: int = rate
        # Lazily populated model handles; None until the first apply() that needs them.
        self.demucs_model: Optional[Any] = None
        self.df_model: Optional[Any] = None
        self.df_state: Optional[Any] = None

    def apply(self, audio_data: np.ndarray, is_live: bool = False) -> np.ndarray:
        """
        Apply the selected audio enhancement method.

        Args:
            audio_data (np.ndarray): The input audio data as a NumPy array.
            is_live (bool): Flag indicating if the processing is for a live
                stream. Only used to warn about Demucs latency.

        Returns:
            np.ndarray: The enhanced audio data. The input array is returned
            unchanged when the method is 'none', when the method is unknown,
            or when the required backend cannot be imported or loaded.
        """
        method = self.enhancement_method
        if method == "none":
            return audio_data
        if method == "deepfilternet":
            return self._apply_deepfilternet(audio_data)
        if method == "demucs":
            return self._apply_demucs(audio_data, is_live)

        warnings.warn(f"Audio enhancement method '{self.enhancement_method}' is not yet implemented. Audio will not be processed.")
        return audio_data

    def _resample(self, audio_tensor: torch.Tensor, orig_freq: int, new_freq: int) -> torch.Tensor:
        """Resample ``audio_tensor`` on ``self.device`` and return the result there."""
        resampler = torchaudio.transforms.Resample(
            orig_freq=orig_freq, new_freq=new_freq
        ).to(self.device)
        # The input must live on the same device as the resampling kernel.
        return resampler(audio_tensor.to(self.device))

    def _apply_deepfilternet(self, audio_data: np.ndarray) -> np.ndarray:
        """Denoise ``audio_data`` with DeepFilterNet, or return it unchanged on failure."""
        if audio_data.size == 0:
            # Nothing to enhance; avoid loading a model for an empty buffer.
            return audio_data

        # Guard the optional dependency the same way the Demucs path does.
        try:
            from df.enhance import enhance, init_df
        except ImportError as e:
            logger.warning(f"DeepFilterNet import failed: {e} - skipping enhancement")
            return audio_data

        if self.df_model is None:
            logger.info("Loading DeepFilterNet denoiser model")
            try:
                self.df_model, self.df_state, _ = init_df()
            except Exception as e:
                logger.warning(f"Failed to load DeepFilterNet model: {e} - skipping enhancement")
                return audio_data

        if self.df_model is None or self.df_state is None:
            warnings.warn("DeepFilterNet model not loaded. Skipping enhancement.")
            return audio_data

        # Convert to torch tensor, add channel dimension for mono audio.
        # The copy keeps the caller's array decoupled from the tensor.
        audio_tensor = torch.from_numpy(audio_data.copy()).float()
        if audio_tensor.dim() == 1:
            audio_tensor = audio_tensor.unsqueeze(0)

        needs_resampling = self.rate != self._DEEPFILTERNET_RATE
        if needs_resampling:
            audio_tensor = self._resample(audio_tensor, self.rate, self._DEEPFILTERNET_RATE)

        # NOTE(review): enhance() is handed a host tensor because DeepFilterNet
        # appears to run its spectral analysis through NumPy - confirm against
        # the df.enhance documentation.
        enhanced_audio = enhance(self.df_model, self.df_state, audio_tensor.cpu())

        if needs_resampling:
            enhanced_audio = self._resample(enhanced_audio, self._DEEPFILTERNET_RATE, self.rate)

        # Convert back to numpy array and remove channel dimension.
        return enhanced_audio.squeeze(0).cpu().numpy()

    def _apply_demucs(self, audio_data: np.ndarray, is_live: bool) -> np.ndarray:
        """Isolate one Demucs stem from ``audio_data`` and return it as peak-normalized mono."""
        if is_live:
            warnings.warn("Demucs is not recommended for live processing due to high latency. Using it anyway.")

        try:
            from demucs.pretrained import get_model
            from demucs.apply import apply_model
        except ImportError as e:
            logger.error(f"Demucs import failed: {e}")
            logger.warning("Demucs not installed - please run 'uv sync' - skipping enhancement")
            return audio_data

        if audio_data.size == 0:
            # torch.max() below would raise on an empty tensor.
            return audio_data

        if self.demucs_model is None:
            logger.info("Loading Demucs model for audio separation")
            model = get_model('htdemucs')
            model.to(self.device)
            model.eval()
            # Publish only a fully prepared model so a failure above does not
            # leave a half-initialized one cached on the instance.
            self.demucs_model = model

        audio_tensor = torch.from_numpy(audio_data).float()
        if audio_tensor.dim() == 1:
            audio_tensor = audio_tensor.unsqueeze(0)

        needs_resampling = self.rate != self._DEMUCS_RATE
        if needs_resampling:
            audio_tensor = self._resample(audio_tensor, self.rate, self._DEMUCS_RATE)

        # Duplicate a single channel into two, then add a leading batch dimension.
        if audio_tensor.shape[0] == 1:
            audio_tensor = audio_tensor.repeat(2, 1)
        audio_tensor = audio_tensor.unsqueeze(0).to(self.device)

        logger.info("Applying Demucs model for audio separation")
        with torch.no_grad():
            sources = apply_model(
                self.demucs_model, audio_tensor, split=True, overlap=0.25
            )

        # Keep one stem of the first batch item and average its channels to mono.
        vocals = sources[0, self._DEMUCS_VOCALS_INDEX]
        vocals_mono_tensor = vocals.mean(dim=0)

        if needs_resampling:
            vocals_mono_tensor = self._resample(
                vocals_mono_tensor.unsqueeze(0), self._DEMUCS_RATE, self.rate
            ).squeeze(0)

        # Peak-normalize, skipping pure silence to avoid dividing by zero.
        max_val = torch.max(torch.abs(vocals_mono_tensor))
        if max_val > 0:
            vocals_mono_tensor = vocals_mono_tensor / max_val

        vocals_mono = vocals_mono_tensor.cpu().numpy().astype(np.float32)
        logger.info("Demucs audio separation completed")
        return vocals_mono