# Source code for src.clients.api_client

"""API client for production embedding services.

This module provides a client for obtaining embeddings from production APIs
including OpenAI, Mistral, and VoyageAI. It handles authentication, batching,
and automatic retry with batch size reduction on errors.

Example:
    Get embeddings from an API::

        from src.clients.api_client import ProductionEmbeddingClient

        client = ProductionEmbeddingClient(
            base_url="https://api.openai.com/v1",
            model="text-embedding-ada-002",
            expected_dimension=1536
        )
        embeddings = client.get_embeddings(["Hello", "World"])
"""

import json
import os
from typing import List

import requests
from dotenv import load_dotenv
from tqdm import tqdm

load_dotenv()


[docs] class ProductionEmbeddingClient: """Client for obtaining embeddings from production APIs. Supports OpenAI-compatible APIs with automatic API key selection based on the model name. Implements automatic batch splitting and retries. Attributes: base_url: Base URL of the API. model: Name of the embedding model. expected_dimension: Expected embedding dimension for validation. timeout: Request timeout in seconds. session: Requests session with authentication headers. """
[docs] def __init__( self, base_url: str, model: str, expected_dimension: int | None = None, timeout: int = 30, initial_batch_size: int | None = None, ) -> None: """Initialize the ProductionEmbeddingClient. Args: base_url: Base URL of the API. model: Name of the embedding model to use. expected_dimension: Expected dimension of embeddings for validation. timeout: Timeout for requests in seconds. initial_batch_size: Initial batch size for requests. """ self.base_url = base_url self.model = model self.expected_dimension = expected_dimension self.timeout = timeout self.session = requests.Session() self.session.headers.update({"Content-Type": "application/json"}) self._initial_batch_size = initial_batch_size # Determines which API key to use based on the model name if "mistral" in model.lower(): api_key = os.environ.get("API_KEY_MISTRAL") elif "voyage" in model.lower(): api_key = os.environ.get("API_KEY_VOYAGEAI") else: api_key = os.environ.get("API_KEY_OPENAI") if api_key: self.session.headers.update({"Authorization": f"Bearer {api_key}"})
[docs] def get_embeddings(self, texts: List[str]) -> List[List[float]]: """Retrieve embeddings for a list of texts via the API. Implements automatic batch splitting for large requests. Args: texts: List of texts to embed. Returns: List of embedding vectors as lists of floats. """ if not texts: return [] batch_size = ( self._initial_batch_size if self._initial_batch_size is not None else len(texts) ) # Start with full batch, will be subdivided if needed return self._get_embeddings_with_retry(texts, initial_batch_size=batch_size)
def _get_embeddings_with_retry( self, texts: List[str], initial_batch_size: int, max_retries: int = 3 ) -> List[List[float]]: """Handle batch subdivision and retries for embedding requests. Automatically reduces batch size on 400 errors related to batch limits. Args: texts: List of texts to embed. initial_batch_size: Starting batch size. max_retries: Maximum number of retry attempts. Returns: List of embedding vectors. """ batch_size = min(initial_batch_size, len(texts)) total_embeddings = [] for i in range(0, len(texts), batch_size): batch_texts = texts[i : i + batch_size] for attempt in range(max_retries): try: embeddings = self._single_api_call(batch_texts) total_embeddings.extend(embeddings) break # Success, move to next batch except requests.exceptions.HTTPError as e: if e.response.status_code == 400: try: error_response = e.response.json() error_message = error_response.get("message", "").lower() # Check if it's a batch size error if any( keyword in error_message for keyword in [ "too many inputs", "split into more batches", "batch size", "request too large", ] ): # Reduce batch size by half new_batch_size = max(1, len(batch_texts) // 2) tqdm.write( f"🔄 Batch too large ({len(batch_texts)} texts), " f"splitting into smaller batches of {new_batch_size}" ) # Recursively process with smaller batches sub_embeddings = self._get_embeddings_with_retry( batch_texts, new_batch_size, max_retries ) total_embeddings.extend(sub_embeddings) break # Success with subdivision else: # Other 400 error, don't retry raise except (json.JSONDecodeError, KeyError): # Can't parse error response, don't retry raise else: # Non-400 error, don't retry raise except requests.exceptions.RequestException as e: if attempt == max_retries - 1: # Last attempt failed error_msg = f"❌ API Error after {max_retries} attempts: {e}" if hasattr(e, "response") and e.response is not None: error_msg += f"\n Status code: {e.response.status_code}" error_msg += ( f"\n URL: {getattr(e.response, 'url', 
'unknown')}" ) error_msg += f"\n Response content: {e.response.text}" tqdm.write(error_msg) return [] return total_embeddings def _single_api_call(self, texts: List[str]) -> List[List[float]]: """Make a single API call without retry logic. Args: texts: List of texts to embed in this call. Returns: List of embedding vectors. Raises: requests.exceptions.HTTPError: On API errors. ValueError: If embedding dimension doesn't match expected. """ url = f"{self.base_url}/embeddings" payload = {"model": self.model, "input": texts} try: # Make the API request response = self.session.post(url, json=payload, timeout=self.timeout) response.raise_for_status() result = response.json() embeddings = [data["embedding"] for data in result["data"]] except requests.exceptions.RequestException as e: tqdm.write(f"❌ API request failed: {e}") if hasattr(e, "response") and e.response is not None: tqdm.write(f" Status code: {e.response.status_code}") tqdm.write(f" URL: {getattr(e.response, 'url', 'unknown')}") tqdm.write(f" Response content: {e.response.text}") # Also log the full response for debugging if hasattr(e, "response") and e.response is not None: try: tqdm.write(f" Full response JSON: {e.response.json()}") except json.JSONDecodeError: tqdm.write(" Could not decode JSON from response.") return [] # Return empty embeddings if self.expected_dimension and embeddings: actual_dimension = len(embeddings[0]) if actual_dimension != self.expected_dimension: raise ValueError( f"Expected dimension {self.expected_dimension}, but got {actual_dimension} for model {self.model}" ) return embeddings