# Source code for src.services.visualization_service

"""Visualization service for ForzaEmbed.

This module provides the VisualizationService class that handles
dimensionality reduction (t-SNE, UMAP, PCA) and caching for embedding
visualizations.
"""

import logging
from typing import Any, Dict

import numpy as np

from ..utils.database import EmbeddingDatabase


[docs] class VisualizationService: """Handle visualization tasks like UMAP, PCA and t-SNE coordinate generation. Manages the computation and caching of projection coordinates for embedding visualizations. Attributes: db: The embedding database for caching coordinates. """
[docs] def __init__(self, db: EmbeddingDatabase) -> None: """Initialize the VisualizationService. Args: db: The embedding database for caching. """ self.db = db
[docs] def get_or_create_projections( self, embeddings: np.ndarray, base_key: str, file_id: str, similarities: np.ndarray, ) -> Dict[str, Any] | None: """Compute or retrieve projection coordinates (UMAP, t-SNE, PCA). Checks the database cache for existing coordinates using method-specific keys. Args: embeddings: Embedding matrix of shape (n_samples, n_dims). base_key: Base cache key for the computation. file_id: Identifier for the file being visualized. similarities: Similarity matrix for determining labels. threshold: Similarity threshold for labeling points. Returns: Dictionary containing projection data for umap, tsne, and pca. Returns None if embeddings have <= 1 sample or on error. """ if embeddings.shape[0] <= 1: return None similarity_scores = similarities.max(axis=0) safe_similarities = self._safe_convert_to_python_types(similarity_scores) methods = ["umap", "tsne", "pca"] results = {} for method in methods: cache_key = f"{base_key}_{method}" cached_data = self.db.get_projection_coordinates(cache_key, file_id) if cached_data is not None: coords = cached_data results[method] = { "x": self._safe_convert_to_python_types(cached_data["x"]), "y": self._safe_convert_to_python_types(cached_data["y"]), "similarities": safe_similarities, "title": f"Visualization for {file_id}", } for k, v in cached_data.items(): if k not in ["x", "y", "labels", "threshold"]: results[method][k] = v continue # Compute new coordinates try: if method == "umap": import umap n_neighbors = min(15, embeddings.shape[0] - 1) n_neighbors = max(2, n_neighbors) reducer = umap.UMAP( n_neighbors=n_neighbors, min_dist=0.1, n_components=2, random_state=42, metric='cosine' ) projections = reducer.fit_transform(embeddings) coords = { "x": projections[:, 0].astype(float).tolist(), "y": projections[:, 1].astype(float).tolist(), "n_neighbors": n_neighbors, "min_dist": 0.1, "metric": 'cosine' } elif method == "pca": from sklearn.decomposition import PCA pca = PCA(n_components=2, random_state=42) 
projections = pca.fit_transform(embeddings) variance_ratio = pca.explained_variance_ratio_ coords = { "x": projections[:, 0].astype(float).tolist(), "y": projections[:, 1].astype(float).tolist(), "explained_variance_1": float(variance_ratio[0]), "explained_variance_2": float(variance_ratio[1]) } elif method == "tsne": from sklearn.manifold import TSNE perplexity_val = min(30, embeddings.shape[0] - 1) perplexity_val = max(1, perplexity_val) tsne = TSNE( n_components=2, perplexity=perplexity_val, random_state=42, max_iter=1000, init="pca", learning_rate="auto", ) projections = tsne.fit_transform(embeddings) coords = { "x": projections[:, 0].astype(float).tolist(), "y": projections[:, 1].astype(float).tolist(), "kl_divergence": float(getattr(tsne, 'kl_divergence_', 0.0)), "n_iter": int(getattr(tsne, 'n_iter_', 0)), "perplexity": perplexity_val, "max_iter": 1000, "init": "pca", "learning_rate": "auto", } self.db.save_projection_coordinates(cache_key, file_id, coords) method_result = { "x": coords["x"], "y": coords["y"], "similarities": safe_similarities, "title": f"Visualization for {file_id}", } for k, v in coords.items(): if k not in ["x", "y"]: method_result[k] = v results[method] = method_result except Exception as e: logging.error(f"Error during {method} calculation for {file_id}: {e}") return results
def _safe_convert_to_python_types(self, data: Any) -> Any: """Recursively convert all NumPy types to native Python types.""" if isinstance(data, np.ndarray): return data.astype(float).tolist() elif isinstance(data, (np.floating, float)): return float(data) elif isinstance(data, (np.integer, int)): return int(data) elif isinstance(data, dict): return { key: self._safe_convert_to_python_types(value) for key, value in data.items() } elif isinstance(data, (list, tuple)): return [self._safe_convert_to_python_types(item) for item in data] else: return data