# Source code for src.clients.transformers_client

"""Transformers client for local embedding generation.

This module provides a client for generating embeddings using the Hugging Face
transformers library directly, with special handling for Jina models.

Example:
    Generate embeddings using Transformers::

        from src.clients.transformers_client import TransformersClient

        embeddings = TransformersClient.get_embeddings(
            texts=["Hello world"],
            model_name="BAAI/bge-small-en-v1.5"
        )
"""

from typing import Dict, Tuple

import torch
from transformers import AutoModel, AutoTokenizer, PreTrainedModel, PreTrainedTokenizer


def mean_pooling(
    token_embeddings: torch.Tensor, attention_mask: torch.Tensor
) -> torch.Tensor:
    """Average token embeddings into a single vector per sequence.

    Padding positions are zeroed out via the attention mask before
    averaging, and the divisor is clamped so a fully-masked row cannot
    cause a division by zero.

    Args:
        token_embeddings: Token-level embeddings; pooled over axis 1,
            so expected shape is (batch, tokens, dim).
        attention_mask: Mask with 1 for real tokens and 0 for padding.

    Returns:
        Mean-pooled sentence embeddings of shape (batch, dim).
    """
    mask = attention_mask.unsqueeze(-1).expand_as(token_embeddings).float()
    summed = (token_embeddings * mask).sum(1)
    counts = mask.sum(1).clamp(min=1e-9)
    return summed / counts
class TransformersClient:
    """Client for managing local transformers embedding models.

    Keeps a singleton-style, class-level cache of loaded model and
    tokenizer pairs, with special handling for Jina models and their
    task labels.

    Attributes:
        _instances: Class-level cache mapping a model name to its loaded
            (model, tokenizer) pair.
    """

    _instances: Dict[str, Tuple[PreTrainedModel, PreTrainedTokenizer]] = {}
[docs] @classmethod def get_instance( cls, model_name: str ) -> Tuple[PreTrainedModel, PreTrainedTokenizer]: """Get or create a transformers model and tokenizer instance. Args: model_name: Name of the transformers model. Returns: Tuple of (model, tokenizer) instances. """ if model_name not in cls._instances: tokenizer = AutoTokenizer.from_pretrained(model_name) model = AutoModel.from_pretrained(model_name, trust_remote_code=True) cls._instances[model_name] = (model, tokenizer) return cls._instances[model_name]
[docs] @classmethod def get_embeddings( cls, texts: list[str], model_name: str, expected_dimension: int | None = None ) -> list[list[float]]: """Generate embeddings using a local transformers model. Handles special cases for Jina models including task labels and different output formats. Args: texts: List of texts to embed. model_name: Name of the transformers model. expected_dimension: Expected embedding dimension for validation. Returns: List of normalized embedding vectors as lists of floats. Raises: ValueError: If embedding dimension doesn't match expected or embeddings cannot be extracted. """ model, tokenizer = cls.get_instance(model_name) # Tokenize sentences encoded_input = tokenizer( texts, padding=True, truncation=True, return_tensors="pt" ) # Compute token embeddings with torch.no_grad(): if "jina" in model_name.lower(): # Les modèles Jina nécessitent obligatoirement task_label try: model_output = model(**encoded_input, task_label="text-matching") except Exception: # Fallback sans task_label pour les anciens modèles Jina model_output = model(**encoded_input) # Pour Jina v4, utiliser single_vec_emb qui contient déjà les embeddings poolés if ( hasattr(model_output, "single_vec_emb") and model_output.single_vec_emb is not None ): sentence_embeddings = model_output.single_vec_emb sentence_embeddings = torch.nn.functional.normalize( sentence_embeddings, p=2, dim=1 ) embeddings = sentence_embeddings.tolist() if expected_dimension and embeddings: actual_dimension = len(embeddings[0]) if actual_dimension != expected_dimension: raise ValueError( f"Expected dimension {expected_dimension}, but got {actual_dimension} for model {model_name}" ) return embeddings # Fallback pour les anciens modèles Jina ou si single_vec_emb n'est pas disponible token_embeddings = None if ( hasattr(model_output, "last_hidden_state") and model_output.last_hidden_state is not None ): token_embeddings = model_output.last_hidden_state elif ( hasattr(model_output, "vlm_last_hidden_states") 
and model_output.vlm_last_hidden_states is not None ): token_embeddings = model_output.vlm_last_hidden_states elif ( hasattr(model_output, "pooler_output") and model_output.pooler_output is not None ): sentence_embeddings = model_output.pooler_output sentence_embeddings = torch.nn.functional.normalize( sentence_embeddings, p=2, dim=1 ) embeddings = sentence_embeddings.tolist() if expected_dimension and embeddings: actual_dimension = len(embeddings[0]) if actual_dimension != expected_dimension: raise ValueError( f"Expected dimension {expected_dimension}, but got {actual_dimension} for model {model_name}" ) return embeddings if token_embeddings is None: raise ValueError( f"Unable to extract embeddings from Jina model output for {model_name}. Available attributes: {[attr for attr in dir(model_output) if not attr.startswith('_')]}" )