# stellascript/processing/diarizer.py
"""
Handles speaker diarization using different methods like Pyannote and VAD with clustering.
"""
from typing import Any, Dict, List, Optional, Tuple
import numpy as np
import torch
from pyannote.audio import Pipeline
from pyannote.core import Segment
from ..logging_config import get_logger
from .speaker_manager import SpeakerManager
logger = get_logger(__name__)
class Diarizer:
    """
    A class to perform speaker diarization on audio data.

    This class supports multiple diarization methods, including the pre-trained
    Pyannote pipeline and a custom VAD-based clustering approach.
    """

    def __init__(self, device: torch.device, method: str, hf_token: Optional[str], rate: int) -> None:
        """
        Initializes the Diarizer.

        Args:
            device (torch.device): The device to run the models on.
            method (str): The diarization method to use ('pyannote' or 'cluster').
            hf_token (Optional[str]): The Hugging Face authentication token for Pyannote.
            rate (int): The sample rate of the audio in Hz.
        """
        self.device: torch.device = device
        self.method: str = method
        self.rate: int = rate
        self.diarization_pipeline: Optional[Pipeline] = self._load_diarization_pipeline(hf_token)
        self.vad_model: Optional[Any] = None
        self.vad_utils: Optional[Dict[str, Any]] = None
    def _load_diarization_pipeline(self, hf_token: Optional[str]) -> Optional[Pipeline]:
        """
        Load the Pyannote diarization pipeline.

        Args:
            hf_token (Optional[str]): The Hugging Face token.

        Returns:
            Optional[Pipeline]: The loaded Pyannote pipeline, or None if not used.
        """
        if self.method != "pyannote":
            return None
        logger.info("Loading pyannote diarization model")
        try:
            pipeline = Pipeline.from_pretrained(
                "pyannote/speaker-diarization-3.1",
                use_auth_token=hf_token,
            )
            if pipeline is None:
                raise RuntimeError("Pipeline.from_pretrained returned None.")
            if self.device.type == "cuda":
                pipeline.model = pipeline.model.half()
            logger.info("Pyannote diarization model loaded successfully")
            return pipeline.to(self.device)
        except AttributeError as e:
            if "'NoneType' object has no attribute 'eval'" in str(e):
                raise RuntimeError(
                    "Failed to load diarization pipeline. Check HUGGING_FACE_TOKEN and user agreement."
                ) from e
            raise
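    # A minimal construction sketch for the pyannote path, assuming a valid
    # token and that the gated "pyannote/speaker-diarization-3.1" terms were
    # accepted on the Hugging Face Hub (values below are illustrative):
    #
    #     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    #     diarizer = Diarizer(device, method="pyannote", hf_token="hf_...", rate=16000)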
    def _ensure_vad_loaded(self) -> None:
        """Lazily load the Silero VAD model on first use."""
        if self.vad_model is None:
            logger.info("Loading Silero VAD model")
            self.vad_model, vad_utils = torch.hub.load(  # type: ignore
                repo_or_dir="snakers4/silero-vad",
                model="silero_vad",
                force_reload=False,
            )
            # The silero-vad hub entry point returns its helpers as a tuple,
            # not a dict; key the one we use by name for downstream lookups.
            self.vad_utils = {"get_speech_timestamps": vad_utils[0]}
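    # Standalone sketch of the lazy-load pattern above (assumes network access
    # for the first torch.hub download; cached on subsequent calls):
    #
    #     model, utils = torch.hub.load("snakers4/silero-vad", "silero_vad")
    #     get_speech_timestamps = utils[0]  # first helper in the returned tuple
    #     stamps = get_speech_timestamps(torch.zeros(16000), model, sampling_rate=16000)
    #     # -> a list of dicts of sample indices, e.g. [{"start": 1523, "end": 48000}]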
    def diarize_pyannote(
        self,
        audio_data: np.ndarray,
        min_speakers: Optional[int] = None,
        max_speakers: Optional[int] = None,
    ) -> List[Tuple[Segment, str, str]]:
        """
        Diarize audio using the Pyannote pipeline.

        Args:
            audio_data (np.ndarray): The audio data to diarize.
            min_speakers (Optional[int]): The minimum number of speakers.
            max_speakers (Optional[int]): The maximum number of speakers.

        Returns:
            List[Tuple[Segment, str, str]]: A list of (segment, track, label) tuples.
        """
        if self.diarization_pipeline is None:
            raise RuntimeError(
                "Diarization pipeline not initialized. Check if method is 'pyannote'."
            )
        diarization_params = {}
        if min_speakers is not None:
            diarization_params["min_speakers"] = min_speakers
        if max_speakers is not None:
            diarization_params["max_speakers"] = max_speakers
        param_log = ", ".join(f"{k}={v}" for k, v in diarization_params.items())
        logger.info(f"Processing audio with pyannote diarization ({param_log or 'default params'})")
        diarization_result = self.diarization_pipeline(
            {"waveform": torch.from_numpy(audio_data).unsqueeze(0), "sample_rate": self.rate},
            **diarization_params,
        )
        annotation = diarization_result
        if not hasattr(annotation, "itertracks"):
            # The result is not an annotation itself, so it is likely a
            # container; look for the annotation among its attributes.
            for value in vars(annotation).values():
                if hasattr(value, "itertracks"):
                    annotation = value
                    break
            else:
                # The loop completed without finding an annotation object.
                raise TypeError(
                    f"Could not find an annotation object in the diarization result: {diarization_result}"
                )
        return list(annotation.itertracks(yield_label=True))
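    # Hedged usage sketch for the pyannote path; `audio` is assumed to be a
    # mono float32 numpy array sampled at `rate` Hz:
    #
    #     for turn, _track, label in diarizer.diarize_pyannote(audio, max_speakers=4):
    #         print(f"{label}: {turn.start:.2f}s - {turn.end:.2f}s")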
    def diarize_cluster(
        self,
        audio_data: np.ndarray,
        speaker_manager: SpeakerManager,
        similarity_threshold: float,
        max_speakers: Optional[int] = None,
    ) -> Tuple[List[Dict[str, Any]], int]:
        """
        Diarize audio using VAD and clustering.

        Args:
            audio_data (np.ndarray): The audio data to diarize.
            speaker_manager (SpeakerManager): The speaker manager for embeddings.
            similarity_threshold (float): The similarity threshold for clustering.
            max_speakers (Optional[int]): The maximum number of speakers.

        Returns:
            Tuple[List[Dict[str, Any]], int]: The list of diarized segments and
            the number of speakers found.
        """
        logger.info("Segmenting speech with Silero VAD")
        self._ensure_vad_loaded()
        assert self.vad_utils is not None, "VAD utils should be loaded by _ensure_vad_loaded"
        get_speech_timestamps = self.vad_utils["get_speech_timestamps"]
        speech_timestamps = get_speech_timestamps(
            torch.from_numpy(audio_data), self.vad_model, sampling_rate=self.rate
        )
        logger.info(f"VAD found {len(speech_timestamps)} potential speech segments")
        logger.info("Filtering and preparing segments for speaker embedding...")
        try:
            from sklearn.cluster import AgglomerativeClustering
            from sklearn.preprocessing import normalize
        except ImportError:
            raise ImportError("scikit-learn is required. Please run: pip install scikit-learn")
        valid_segments_info = []
        audio_segments_for_batch = []
        for ts in speech_timestamps:
            start_samples, end_samples = ts["start"], ts["end"]
            audio_segment = audio_data[start_samples:end_samples]
            # Skip segments shorter than 0.5 s; they are too short to yield a
            # reliable speaker embedding.
            if len(audio_segment) < self.rate * 0.5:
                continue
            valid_segments_info.append({
                "turn": Segment(start_samples / self.rate, end_samples / self.rate),
                "audio_segment": audio_segment,
            })
            audio_segments_for_batch.append(audio_segment)
        if not valid_segments_info:
            logger.warning("No valid speech segments found after filtering.")
            return [], 0
        logger.info(f"Extracted {len(audio_segments_for_batch)} valid segments for embedding.")
        all_embeddings = speaker_manager.get_embeddings(audio_segments_for_batch)
        if all_embeddings.size == 0:
            logger.warning("Speaker embedding extraction resulted in no embeddings.")
            return [], 0
        if all_embeddings.ndim == 3 and all_embeddings.shape[1] == 1:
            all_embeddings = all_embeddings.squeeze(1)
        logger.info("Identifying speakers with Agglomerative Clustering...")
        normalized_embeddings = normalize(all_embeddings, norm="l2", axis=1)
        clustering_params: Dict[str, Any] = {
            "metric": "cosine",
            "linkage": "average",
        }
        if max_speakers:
            clustering_params["n_clusters"] = max_speakers
            logger.info(f"Clustering with a fixed number of speakers: {max_speakers}")
        else:
            # AgglomerativeClustering requires n_clusters=None when a
            # distance_threshold is given (its default is n_clusters=2).
            clustering_params["n_clusters"] = None
            clustering_params["distance_threshold"] = 1 - similarity_threshold
            logger.info(f"Clustering with similarity threshold: {similarity_threshold}")
        clustering = AgglomerativeClustering(**clustering_params).fit(normalized_embeddings)
        cluster_labels = clustering.labels_
        found_speakers = len(set(cluster_labels))
        logger.info(f"Clustering identified {found_speakers} unique speakers.")
        segments_with_speakers = [{
            "speaker_label": f"SPEAKER_{label:02d}",
            "turn": info["turn"],
            "audio_segment": info["audio_segment"],
        } for info, label in zip(valid_segments_info, cluster_labels)]
        return segments_with_speakers, found_speakers
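    # Hedged usage sketch for the clustering path; `speaker_manager` is any
    # SpeakerManager whose get_embeddings() returns one embedding per segment,
    # and 0.75 is an illustrative threshold:
    #
    #     segments, n = diarizer.diarize_cluster(audio, speaker_manager,
    #                                            similarity_threshold=0.75)
    #     for seg in segments:
    #         print(seg["speaker_label"], seg["turn"])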
    def apply_vad_to_chunk(self, audio_chunk: np.ndarray) -> float:
        """
        Apply VAD to a small audio chunk for live subtitle mode.

        Args:
            audio_chunk (np.ndarray): The audio chunk to process.

        Returns:
            float: The speech probability for the chunk.
        """
        self._ensure_vad_loaded()
        assert self.vad_model is not None, "VAD model should be loaded by _ensure_vad_loaded"
        # Copy first: torch.from_numpy shares memory with the source array,
        # and the incoming chunk may be a view of a live capture buffer.
        audio_tensor = torch.from_numpy(np.copy(audio_chunk))
        speech_prob = self.vad_model(audio_tensor, self.rate).item()
        return speech_prob
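if __name__ == "__main__":
    # Minimal smoke test for the VAD path (a sketch, not part of the public
    # API): cluster mode needs no Hugging Face token, so only the Silero VAD
    # model is downloaded (network access assumed on first run). Recent
    # silero-vad releases expect fixed-size chunks (512 samples at 16 kHz),
    # which this assumes.
    diarizer = Diarizer(
        device=torch.device("cpu"), method="cluster", hf_token=None, rate=16000
    )
    silence = np.zeros(512, dtype=np.float32)
    logger.info(
        "Speech probability on silence: %.3f", diarizer.apply_vad_to_chunk(silence)
    )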