# Source code for src.services.visualization_service

"""Visualization service for ForzaEmbed.

This module provides the VisualizationService class that handles
t-SNE coordinate generation and caching for embedding visualizations.

Example:
    Generate t-SNE visualization data::

        from src.services.visualization_service import VisualizationService

        service = VisualizationService(db)
        tsne_data = service.get_or_create_tsne_data(
            embeddings, "key", "file_id", similarities, 0.5
        )
"""

import logging
from typing import Any

import numpy as np

from ..utils.database import EmbeddingDatabase


[docs] class VisualizationService: """Handle visualization tasks like t-SNE coordinate generation. Manages the computation and caching of t-SNE coordinates for embedding visualizations. Attributes: db: The embedding database for caching t-SNE coordinates. """
[docs] def __init__(self, db: EmbeddingDatabase) -> None: """Initialize the VisualizationService. Args: db: The embedding database for caching. """ self.db = db
[docs] def get_or_create_tsne_data( self, embeddings: np.ndarray, tsne_key: str, file_id: str, similarities: np.ndarray, threshold: float, ) -> dict[str, Any] | None: """Compute or retrieve t-SNE coordinates for a given combination. Checks the database cache for existing t-SNE coordinates. If not found, computes new coordinates using sklearn's TSNE implementation. Args: embeddings: Embedding matrix of shape (n_samples, n_dims). tsne_key: Cache key for the t-SNE computation. file_id: Identifier for the file being visualized. similarities: Similarity matrix for determining labels. threshold: Similarity threshold for labeling points. Returns: Dictionary containing t-SNE visualization data with keys: - 'x': List of x-coordinates. - 'y': List of y-coordinates. - 'labels': List of threshold-based labels. - 'similarities': List of similarity scores. - 'title': Visualization title. - 'threshold': The threshold value used. Returns None if embeddings have <= 1 sample or on error. """ if embeddings.shape[0] <= 1: return None # Check if t-SNE coordinates already exist cached_tsne = self.db.get_tsne_coordinates(tsne_key, file_id) if cached_tsne is not None: # Use existing coordinates but recalculate labels based on new similarities similarity_scores = similarities.max(axis=0) scatter_labels = [ "Above Threshold" if s >= threshold else "Below Threshold" for s in similarity_scores ] # Ensure all data is native Python types tsne_data = { "x": self._safe_convert_to_python_types(cached_tsne["x"]), "y": self._safe_convert_to_python_types(cached_tsne["y"]), "labels": scatter_labels, "similarities": self._safe_convert_to_python_types(similarity_scores), "title": f"t-SNE Visualization for {file_id}", "threshold": float(threshold), } return tsne_data # Compute new t-SNE coordinates try: from sklearn.manifold import TSNE tsne = TSNE( n_components=2, perplexity=min(30, embeddings.shape[0] - 1), random_state=42, max_iter=300, ) tsne_results = tsne.fit_transform(embeddings) # Save 
coordinates for reuse tsne_coords = { "x": tsne_results[:, 0].astype(float).tolist(), "y": tsne_results[:, 1].astype(float).tolist(), } self.db.save_tsne_coordinates(tsne_key, file_id, tsne_coords) # Calculate labels based on current similarities similarity_scores = similarities.max(axis=0) scatter_labels = [ "Above Threshold" if s >= threshold else "Below Threshold" for s in similarity_scores ] return { "x": tsne_coords["x"], "y": tsne_coords["y"], "labels": scatter_labels, "similarities": self._safe_convert_to_python_types(similarity_scores), "title": f"t-SNE Visualization for {file_id}", "threshold": float(threshold), } except Exception as e: logging.error(f"Error during t-SNE calculation for {file_id}: {e}") return None
def _safe_convert_to_python_types(self, data: Any) -> Any: """Recursively convert all NumPy types to native Python types. Args: data: Data to convert, can be ndarray, scalar, dict, list, or other. Returns: Data with all NumPy types converted to native Python equivalents. """ if isinstance(data, np.ndarray): return data.astype(float).tolist() elif isinstance(data, (np.floating, float)): return float(data) elif isinstance(data, (np.integer, int)): return int(data) elif isinstance(data, dict): return { key: self._safe_convert_to_python_types(value) for key, value in data.items() } elif isinstance(data, (list, tuple)): return [self._safe_convert_to_python_types(item) for item in data] else: return data