# Source code for src.reporting.aggregator

"""Data aggregation module for ForzaEmbed reporting.

This module provides the DataAggregator class that handles aggregation
and caching of processed data from the database for report generation.

Example:
    Aggregate data for reporting::

        from src.reporting.aggregator import DataAggregator

        aggregator = DataAggregator(db, output_dir, "config_name")
        data = aggregator.get_aggregated_data()
"""

import logging
from pathlib import Path
from typing import Any, cast

import joblib
import numpy as np
from tqdm import tqdm

from ..utils.database import EmbeddingDatabase


class DataAggregator:
    """Handle aggregation and caching of processed data from the database.

    Aggregates processing results from the database into a format suitable
    for report generation, with caching to avoid redundant computation.

    Attributes:
        db: The embedding database containing results.
        output_dir: Directory path for cache files.
        cache_path: Path to the cache file.
    """

    def __init__(
        self, db: EmbeddingDatabase, output_dir: Path, config_name: str
    ) -> None:
        """Initialize the DataAggregator.

        Args:
            db: The embedding database containing results.
            output_dir: Directory path for cache files.
            config_name: Name of the configuration for cache file prefix.
        """
        self.db = db
        self.output_dir = output_dir
        self.cache_path = self.output_dir / f"{config_name}_reports_cache.joblib"

    def get_aggregated_data(self) -> dict[str, Any] | None:
        """Load aggregated data from cache if valid, otherwise aggregate from scratch.

        Checks if the cache is newer than the database modification time.
        If valid, loads from cache; otherwise, aggregates fresh data.

        Returns:
            Dictionary containing aggregated data for reporting, or None if
            no processing results are available. Contains keys:

            - all_results: Raw results from database.
            - processed_data_for_interactive_page: Optimized web data.
            - all_models_metrics: Metrics organized by model.
            - model_embeddings_for_variance: Embeddings for analysis.
            - total_combinations: Count of model combinations.
        """
        db_mod_time = self.db.get_db_modification_time()
        use_cache = (
            self.cache_path.exists()
            and self.cache_path.stat().st_mtime > db_mod_time
        )
        if use_cache:
            # Lazy %-args: the message is only formatted if actually emitted.
            logging.info("Loading aggregated data from cache: %s", self.cache_path)
            return joblib.load(self.cache_path)

        logging.info("No valid cache found. Aggregating data from scratch...")
        all_results = self.db.get_all_processing_results()
        if not all_results:
            logging.warning("No processing results found in the database.")
            return None

        aggregated_data = self._aggregate_data(all_results)
        joblib.dump(aggregated_data, self.cache_path)
        logging.info("Saved aggregated data to cache: %s", self.cache_path)
        return aggregated_data

    def _aggregate_data(self, all_results: dict[str, Any]) -> dict[str, Any]:
        """Aggregate data from results for reporting.

        Args:
            all_results: Dictionary of processing results from database.

        Returns:
            Dictionary containing aggregated and processed data.
        """
        processed_data_for_interactive_page: dict[str, Any] = {"files": {}}
        all_models_metrics: dict[str, list[dict[str, Any]]] = {}
        model_embeddings_for_variance: dict[str, dict[str, Any]] = {}

        # Normalize dot product scores globally up front so every consumer
        # below sees the rescaled similarities.
        all_results = self._normalize_dot_product_globally(all_results)

        for model_name, model_results in tqdm(
            all_results.items(), desc="Aggregating data for reports"
        ):
            aggregated_embeddings: list[Any] = []
            aggregated_labels: list[Any] = []
            detailed_metrics: list[dict[str, Any]] = []

            # Single pass over this model's files (instead of three separate
            # loops): collect embeddings/labels, build the interactive-page
            # entry, and record per-file metrics.
            for file_id, file_data in model_results.get("files", {}).items():
                if file_data.get("embeddings") is not None:
                    aggregated_embeddings.extend(file_data["embeddings"])
                if file_data.get("labels") is not None:
                    aggregated_labels.extend(file_data["labels"])

                file_entry = processed_data_for_interactive_page[
                    "files"
                ].setdefault(
                    file_id,
                    {
                        "fileName": file_data.get("file_name", file_id),
                        "embeddings": {},
                    },
                )
                file_entry["embeddings"][model_name] = {
                    "phrases": file_data.get("phrases", []),
                    "similarities": file_data.get("similarities", []),
                    "metrics": file_data.get("metrics", {}),
                    "scatter_plot_data": file_data.get("scatter_plot_data"),
                }

                if file_data.get("metrics"):
                    # NOTE(review): mirrors the original behavior of labelling
                    # the record with the file id, not the human-readable
                    # file_name used above — confirm downstream metric tables
                    # expect the id here.
                    metric_record: dict[str, Any] = {"file_name": file_id}
                    metric_record.update(file_data["metrics"])
                    detailed_metrics.append(metric_record)

            model_embeddings_for_variance[model_name] = {
                "embeddings": np.array(aggregated_embeddings)
                if aggregated_embeddings
                else np.array([]),
                "labels": aggregated_labels,
            }
            all_models_metrics[model_name] = detailed_metrics

        optimized_data = self._optimize_data_for_web(
            processed_data_for_interactive_page
        )

        return {
            "all_results": all_results,
            "processed_data_for_interactive_page": optimized_data,
            "all_models_metrics": all_models_metrics,
            "model_embeddings_for_variance": model_embeddings_for_variance,
            "total_combinations": len(all_results),
        }

    def _optimize_data_for_web(self, data: dict[str, Any]) -> dict[str, Any]:
        """Optimize the data structure for web output by rounding floats.

        Args:
            data: Data dictionary to optimize.

        Returns:
            Optimized data dictionary with floats rounded to 4 decimal places.
        """

        def round_floats(obj: Any) -> Any:
            # Recurse into lists and dicts; any other type passes through.
            if isinstance(obj, list):
                return [round_floats(v) for v in obj]
            if isinstance(obj, dict):
                return {k: round_floats(v) for k, v in obj.items()}
            if isinstance(obj, float):
                return round(obj, 4)
            return obj

        return cast(dict[str, Any], round_floats(data))

    def _normalize_dot_product_globally(
        self, all_results: dict[str, Any]
    ) -> dict[str, Any]:
        """Perform global min-max scaling on dot product similarities.

        Args:
            all_results: Dictionary of processing results.

        Returns:
            Results with dot product similarities normalized to [0, 1].
        """
        for model_name, model_results in all_results.items():
            model_info = self.db.get_model_info(model_name)
            # Only models scored with raw dot products need rescaling.
            if not model_info or model_info.get("similarity_metric") != "dot_product":
                continue

            all_similarities: list[float] = [
                sim
                for file_data in model_results.get("files", {}).values()
                if file_data.get("similarities") is not None
                for sim in file_data["similarities"]
            ]
            if not all_similarities:
                continue

            dot_min, dot_max = min(all_similarities), max(all_similarities)
            denominator = dot_max - dot_min
            if denominator == 0:
                # All scores identical: map everything to the midpoint.
                for file_data in model_results.get("files", {}).values():
                    if "similarities" in file_data:
                        file_data["similarities"] = [0.5] * len(
                            file_data["similarities"]
                        )
                continue

            for file_data in model_results.get("files", {}).values():
                if "similarities" in file_data:
                    file_data["similarities"] = [
                        (s - dot_min) / denominator
                        for s in file_data["similarities"]
                    ]

        return all_results

    def touch_cache(self) -> None:
        """Update the cache file's modification time to the current time."""
        if self.cache_path.exists():
            self.cache_path.touch()
            logging.info("Updated cache timestamp: %s", self.cache_path)