
# stellascript/processing/transcriber.py

"""
Handles audio transcription using the WhisperX library.

This module provides a Transcriber class that encapsulates the logic for loading
a Whisper model and using it to transcribe audio segments. It supports generating
both full-text transcriptions and detailed word-level timestamps.
"""

from logging import getLogger
from typing import Any, List, Optional, Tuple, Union

import numpy as np
import torch
import whisperx

logger = getLogger(__name__)


class Transcriber:
    """
    A wrapper for the WhisperX transcription model.

    This class manages the loading of the WhisperX model and provides a
    simple interface for transcribing audio data. It can be configured for
    different model sizes, languages, and devices.
    """

    def __init__(self, model_id: str, device: torch.device, language: str) -> None:
        """
        Initializes the Transcriber.

        Args:
            model_id (str): The identifier of the Whisper model to use
                (e.g., 'large-v3').
            device (torch.device): The device to run the model on
                ('cuda' or 'cpu').
            language (str): The language of the audio to be transcribed.
        """
        self.device: torch.device = device
        self.language: str = language
        self.model_id: str = model_id
        self.whisperx_model: Optional[Any] = None
        self._load_models()

    def _load_models(self) -> None:
        """
        Loads the WhisperX transcription model.

        Selects an optimized compute type for the target device:
        'float16' on CUDA, 'int8' on CPU.

        Raises:
            RuntimeError: If the model fails to load.
        """
        try:
            logger.info(f"Loading whisperx model '{self.model_id}'...")
            compute_type = "float16" if self.device.type == "cuda" else "int8"
            device_str = self.device.type
            self.whisperx_model = whisperx.load_model(
                self.model_id,
                device_str,
                compute_type=compute_type,
                language=self.language,
            )
            logger.info(
                f"WhisperX model '{self.model_id}' loaded successfully on "
                f"{device_str} with {compute_type} compute type."
            )
        except Exception as e:
            logger.exception(f"Failed to load whisperx model '{self.model_id}'.")
            raise RuntimeError(
                f"Failed to load whisperx model '{self.model_id}'."
            ) from e
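
    # A minimal construction sketch (illustrative values, not part of this
    # module; assumes the 'large-v3' checkpoint can be downloaded):
    #
    #     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    #     transcriber = Transcriber(
    #         model_id="large-v3", device=device, language="en"
    #     )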

    def transcribe_segment(
        self,
        audio_data: np.ndarray,
        rate: int,
        padding_duration: float,
        word_timestamps: bool = False,
    ) -> Union[str, Tuple[List[Any], str]]:
        """
        Transcribes a single audio segment.

        The audio segment is padded with silence on both sides to improve
        transcription accuracy at the beginning and end of the speech.

        Args:
            audio_data (np.ndarray): The raw audio data of the segment.
            rate (int): The sample rate of the audio.
            padding_duration (float): The duration of the silence padding,
                in seconds.
            word_timestamps (bool): If True, returns word-level timestamps.

        Returns:
            Union[str, Tuple[List[Any], str]]:
                - If `word_timestamps` is False, the transcribed text as a
                  string.
                - If `word_timestamps` is True, a tuple containing a list of
                  segment objects (with word details) and the full
                  transcribed text.
        """
        audio_duration = len(audio_data) / rate
        if audio_duration < 0.5:
            logger.debug(f"Segment too short ({audio_duration:.2f}s), skipping")
            return "" if not word_timestamps else ([], "")

        padding_samples = int(padding_duration * rate)
        silence_padding = np.zeros(padding_samples, dtype=np.float32)
        padded_audio = np.concatenate([silence_padding, audio_data, silence_padding])

        return self._transcribe_with_whisperx(padded_audio, rate, word_timestamps)
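
    # For example, with rate=16000 and padding_duration=0.5, the segment is
    # padded with int(0.5 * 16000) = 8000 zero samples on each side, and the
    # 0.5 s minimum-duration check skips anything shorter than 8000 samples.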

    def _transcribe_with_whisperx(
        self,
        audio_data: np.ndarray,
        rate: int,
        word_timestamps: bool = False,
    ) -> Union[str, Tuple[List[Any], str]]:
        """
        Internal method that performs transcription using the WhisperX model.

        If word timestamps are requested, it also runs alignment to obtain
        precise timings for each word.

        Args:
            audio_data (np.ndarray): The audio data to transcribe.
            rate (int): The sample rate of the audio (currently unused).
            word_timestamps (bool): Whether to generate word-level timestamps.

        Returns:
            Union[str, Tuple[List[Any], str]]: The transcription result.
        """
        if self.whisperx_model is None:
            logger.warning("WhisperX model not loaded. Returning empty transcription.")
            return "" if not word_timestamps else ([], "")

        result = self.whisperx_model.transcribe(audio_data, batch_size=16)

        if not word_timestamps:
            full_text = " ".join(
                seg.get("text", "") for seg in result.get("segments", [])
            ).strip()
            return full_text

        try:
            device_str = self.device.type
            model_a, metadata = whisperx.load_align_model(
                language_code=self.language, device=device_str
            )
            result = whisperx.align(
                result["segments"],
                model_a,
                metadata,
                audio_data,
                device_str,
                return_char_alignments=False,
            )
        except Exception as e:
            logger.error(
                f"Failed to align transcription: {e}. Returning non-aligned segments."
            )
            full_text = " ".join(
                seg.get("text", "") for seg in result.get("segments", [])
            ).strip()
            return [], full_text

        segments = []
        full_text = ""
        for segment in result.get("segments", []):
            words = []
            for word_info in segment.get("words", []):
                if not isinstance(word_info, dict):
                    continue
                word = word_info.get("word")
                start = word_info.get("start")
                end = word_info.get("end")
                if word is None or start is None or end is None:
                    continue
                # Lightweight attribute container exposing .word/.start/.end.
                word_obj = type(
                    "Word",
                    (),
                    {"word": word, "start": float(start), "end": float(end)},
                )()
                words.append(word_obj)

            segment_text = segment.get("text", "").strip()
            segment_start = segment.get("start")
            segment_end = segment.get("end")
            if segment_start is None or segment_end is None:
                continue
            segment_obj = type(
                "Segment",
                (),
                {
                    "text": segment_text,
                    "words": words,
                    "start": float(segment_start),
                    "end": float(segment_end),
                },
            )()
            segments.append(segment_obj)
            full_text += segment_text + " "

        return segments, full_text.strip()
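

# A minimal end-to-end sketch of how this module might be exercised. The model
# id, device choice, and the synthetic silent input are illustrative
# assumptions, not part of the module's API.
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    transcriber = Transcriber(model_id="large-v3", device=device, language="en")

    # One second of silence at 16 kHz; real callers would pass speech audio.
    rate = 16000
    audio = np.zeros(rate, dtype=np.float32)

    text = transcriber.transcribe_segment(audio, rate, padding_duration=0.5)
    print(f"Plain transcription: {text!r}")

    segments, full_text = transcriber.transcribe_segment(
        audio, rate, padding_duration=0.5, word_timestamps=True
    )
    for seg in segments:
        print(f"[{seg.start:.2f}-{seg.end:.2f}] {seg.text}")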