# stellascript/processing/speaker_manager.py
"""
Manages speaker identification and embedding storage.
This module is responsible for loading a speaker recognition model, generating
embeddings for audio segments, and assigning speaker IDs based on similarity.
It maintains a registry of known speakers and their corresponding embeddings.
"""
import os
from typing import Dict, List, Optional, Union
import numpy as np
import torch
from speechbrain.inference import SpeakerRecognition
from ..logging_config import get_logger
logger = get_logger(__name__)
class SpeakerManager:
    """
    Handles speaker embeddings and identification.

    This class uses a pre-trained speaker recognition model to create vector
    embeddings from audio segments. It can then compare these embeddings to
    identify known speakers or register new ones.

    Speaker IDs are strings of the form ``SPEAKER_NN`` and are handed out in
    increasing order, starting at ``SPEAKER_01``.
    """

    # Model source passed to ``SpeakerRecognition.from_hparams``.
    _MODEL_SOURCE = "speechbrain/spkrec-ecapa-voxceleb"

    # Share of the stored embedding that is kept when a newly matched
    # embedding is blended into a known speaker's entry.
    _EMBEDDING_UPDATE_WEIGHT = 0.7

    def __init__(self, device: torch.device, similarity_threshold: float) -> None:
        """
        Initializes the SpeakerManager.

        The embedding model is loaded immediately, so construction can raise
        whatever ``_load_speaker_embedding_model`` raises.

        Args:
            device (torch.device): The device to run the model on (e.g., 'cuda' or 'cpu').
            similarity_threshold (float): The cosine similarity threshold for
                identifying a speaker. A match requires a similarity strictly
                greater than this value.
        """
        self.device: torch.device = device
        self.similarity_threshold: float = similarity_threshold
        self.embedding_model: SpeakerRecognition = self._load_speaker_embedding_model()
        # Maps speaker ID -> unit-length (L2-normalized) flattened embedding.
        self.speaker_embeddings_normalized: Dict[str, np.ndarray] = {}
        # Number used to build the next "SPEAKER_NN" identifier.
        self.next_speaker_id: int = 1

    def _load_speaker_embedding_model(self) -> SpeakerRecognition:
        """
        Loads the speaker embedding model.

        This method includes a workaround for a known symlink issue on Windows
        by temporarily changing the local fetching strategy in SpeechBrain and
        retrying the load once.

        Returns:
            SpeakerRecognition: The loaded speaker recognition model, moved to
            ``self.device`` (and converted to half precision on CUDA).

        Raises:
            OSError: If loading fails with an OS error that is not the
                recognized Windows symlink-privilege error.
            RuntimeError: If loading completes without producing a model.
        """
        # Set SPEECHBRAIN_CACHE_DIR to a folder under the current working directory.
        os.environ["SPEECHBRAIN_CACHE_DIR"] = os.path.join(
            os.getcwd(), "speechbrain_cache"
        )

        embedding_model = None
        try:
            embedding_model = SpeakerRecognition.from_hparams(source=self._MODEL_SOURCE)
        except OSError as e:
            message = str(e)
            # Only the symlink-privilege failure (French or English Windows
            # message) is handled here; anything else propagates unchanged.
            if "privilège nécessaire" not in message and "WinError 1314" not in message:
                raise

            logger.info("Windows symlink issue detected - using copy strategy")
            import speechbrain.utils.fetching

            fetching = speechbrain.utils.fetching
            original_strategy = getattr(fetching, "LOCAL_STRATEGY", None)
            copy_strategy_class = getattr(fetching, "CopyStrategy")
            setattr(fetching, "LOCAL_STRATEGY", copy_strategy_class())
            try:
                embedding_model = SpeakerRecognition.from_hparams(source=self._MODEL_SOURCE)
            finally:
                # Restore the previous strategy even if the retry failed.
                # Compare against None rather than truthiness so a falsy but
                # valid strategy object is still restored.
                if original_strategy is not None:
                    setattr(fetching, "LOCAL_STRATEGY", original_strategy)

        if embedding_model is None:
            raise RuntimeError("Failed to load speaker embedding model.")

        # NOTE(review): on CUDA the model is converted to half precision while
        # get_embeddings() feeds float32 tensors - confirm SpeechBrain handles
        # the dtype difference on the target setup.
        if self.device.type == "cuda":
            embedding_model = embedding_model.half()
        return embedding_model.to(self.device)

    def _register_speaker(self, embedding_norm: np.ndarray) -> str:
        """
        Stores a normalized embedding under a freshly allocated speaker ID.

        Args:
            embedding_norm (np.ndarray): Unit-length, flattened embedding.

        Returns:
            str: The newly created speaker ID.
        """
        speaker_id = f"SPEAKER_{self.next_speaker_id:02d}"
        self.speaker_embeddings_normalized[speaker_id] = embedding_norm
        self.next_speaker_id += 1
        return speaker_id

    def get_speaker_id(self, embedding: Union[np.ndarray, torch.Tensor]) -> Optional[str]:
        """
        Gets or assigns a speaker ID based on embedding similarity.

        Compares the provided embedding with stored embeddings of known speakers.
        If a match is found above the similarity threshold, the existing speaker ID
        is returned and that speaker's stored embedding is nudged towards the new
        one. Otherwise, a new speaker is registered.

        Args:
            embedding (Union[np.ndarray, torch.Tensor]): The speaker embedding to
                identify. It is flattened before comparison.

        Returns:
            Optional[str]: The assigned speaker ID, or None if the embedding is
            invalid (zero norm, or containing NaN/inf values).

        Raises:
            ValueError: Propagated from ``np.dot`` if the embedding length does
                not match the stored embeddings.
        """
        if isinstance(embedding, torch.Tensor):
            # detach() so tensors that still track gradients can be converted.
            embedding = embedding.detach().cpu().numpy()
        embedding_flat = np.array(embedding, dtype=np.float32, copy=True).flatten()

        norm = np.linalg.norm(embedding_flat)
        if not np.isfinite(norm):
            # A NaN/inf embedding would be stored as a speaker that nothing can
            # ever match (every dot product with it is NaN), so reject it.
            logger.debug("Non-finite embedding detected - skipping")
            return None
        if norm == 0:
            logger.debug("Zero norm embedding detected - skipping")
            return None
        embedding_norm = embedding_flat / norm

        if not self.speaker_embeddings_normalized:
            speaker_id = self._register_speaker(embedding_norm)
            logger.info(f"Registered first speaker as {speaker_id}.")
            return speaker_id

        # Both vectors are unit length, so the dot product is the cosine similarity.
        best_match: Optional[str] = None
        best_similarity: float = -1.0
        for speaker_id, stored_embedding_norm in self.speaker_embeddings_normalized.items():
            similarity = float(np.dot(embedding_norm, stored_embedding_norm))
            logger.debug(f"Comparing with {speaker_id}: similarity = {similarity:.4f}")
            if similarity > best_similarity:
                best_similarity = similarity
                best_match = speaker_id
        logger.debug(f"Best match: {best_match} with similarity {best_similarity:.4f} (threshold: {self.similarity_threshold})")

        if best_match is None or best_similarity <= self.similarity_threshold:
            speaker_id = self._register_speaker(embedding_norm)
            logger.info(f"Similarity {best_similarity:.4f} is below threshold. Registered new speaker: {speaker_id}")
            return speaker_id

        logger.info(f"Segment assigned to existing speaker {best_match} with similarity {best_similarity:.4f}.")
        # Blend the new embedding into the stored one and re-normalize, so the
        # stored entry follows the speaker over time.
        weight = self._EMBEDDING_UPDATE_WEIGHT
        existing_embedding = self.speaker_embeddings_normalized[best_match]
        updated_embedding = (weight * existing_embedding) + ((1 - weight) * embedding_norm)
        updated_norm = np.linalg.norm(updated_embedding)
        if updated_norm > 0:
            self.speaker_embeddings_normalized[best_match] = updated_embedding / updated_norm
            logger.debug(f"Updated embedding for {best_match}.")
        return best_match

    def get_embeddings(self, audio_segments: List[np.ndarray]) -> np.ndarray:
        """
        Gets embeddings for a batch of audio segments.

        This method attempts to process all segments in a single batch for efficiency.
        Shorter segments are zero-padded to the longest one, and their relative
        lengths are passed to the model so the padding is not treated as audio.
        If batch processing fails, it falls back to processing segments one by one.

        Args:
            audio_segments (List[np.ndarray]): A list of audio segments as NumPy arrays.

        Returns:
            np.ndarray: A NumPy array of embeddings for the processed segments.
            An empty array is returned when there is no input or when every
            segment fails. In the sequential fallback, segments that fail are
            skipped, so the result can contain fewer entries than the input.
        """
        if not audio_segments:
            return np.array([])

        try:
            lengths = [len(seg) for seg in audio_segments]
            max_len = max(lengths)
            padded_segments = [
                np.pad(seg, (0, max_len - seg_len), mode="constant") if seg_len < max_len else seg
                for seg, seg_len in zip(audio_segments, lengths)
            ]
            batch_tensor = torch.stack(
                [torch.from_numpy(seg).float() for seg in padded_segments]
            ).to(self.device)
            # Lengths relative to the longest segment (1.0 for the longest).
            # A zero max_len raises here and is handled by the fallback below.
            wav_lens = torch.tensor(
                [seg_len / max_len for seg_len in lengths],
                dtype=torch.float32,
                device=self.device,
            )
            batch_embeddings = self.embedding_model.encode_batch(batch_tensor, wav_lens=wav_lens)
            return batch_embeddings.cpu().numpy()
        except Exception as e:
            # Deliberate best-effort: any batch failure triggers the per-segment path.
            logger.debug(f"Batch embedding failed: {e}. Falling back to sequential processing.")

        embeddings = []
        for seg in audio_segments:
            try:
                tensor = torch.from_numpy(seg).float().unsqueeze(0).to(self.device)
                embedding = self.embedding_model.encode_batch(tensor)
                embeddings.append(embedding.cpu().numpy())
            except Exception as e2:
                logger.debug(f"Failed to process segment in fallback: {e2}")

        if not embeddings:
            return np.array([])
        return np.concatenate(embeddings, axis=0)