"""Database management module for ForzaEmbed.
This module provides the EmbeddingDatabase class for managing all database
operations including storing embeddings, results, and metadata. It implements
intelligent quantization for efficient storage and caching mechanisms for
improved performance.
Example:
Basic database usage::
from src.utils.database import EmbeddingDatabase
db = EmbeddingDatabase("results.db", config)
db.save_embeddings_batch("model_name", embeddings_dict)
cached = db.get_embeddings_by_hashes("model_name", ["hash1", "hash2"])
"""
import logging
import os
import zlib
from typing import Any, Dict, List, Optional, Tuple, Union
import msgpack
import numpy as np
from sqlalchemy import create_engine, delete, select, text, update
from sqlalchemy.dialects.sqlite import insert as sqlite_insert
from sqlalchemy.orm import sessionmaker
from ..core.config import AppConfig
from .models import (
Base,
EmbeddingCache,
EvaluationMetric,
GeneratedFile,
GlobalChart,
Model,
ProcessingResult,
TSNECoordinate,
)
[docs]
class EmbeddingDatabase:
    """Manage SQLite database for embeddings, results, and metadata.

    This class handles all database operations for ForzaEmbed, including
    storing and retrieving embeddings, processing results, and various
    metadata. Implements intelligent quantization to reduce storage size.

    Attributes:
        db_path: Path to the SQLite database file.
        config: Application configuration (dict or AppConfig).
        quantization_enabled: Whether intelligent quantization is enabled.
        engine: SQLAlchemy database engine.
        Session: SQLAlchemy session factory.
    """
[docs]
def __init__(
self, db_path: str, config: Union[AppConfig, Dict[str, Any]]
) -> None:
"""Initialize the EmbeddingDatabase.
Args:
db_path: Path to the SQLite database file.
config: Application configuration, either as AppConfig or dict.
"""
self.db_path = db_path
self.config = config
if isinstance(config, dict):
self.quantization_enabled: bool = config.get("database", {}).get(
"intelligent_quantization", True
)
else:
self.quantization_enabled = config.database.intelligent_quantization
db_dir = os.path.dirname(db_path)
if db_dir:
os.makedirs(db_dir, exist_ok=True)
# Initialize SQLAlchemy engine and session
self.engine = create_engine(f"sqlite:///{self.db_path}")
self.Session = sessionmaker(bind=self.engine)
self._init_database()
def _init_database(self) -> None:
"""Initialize the database tables."""
Base.metadata.create_all(self.engine)
[docs]
def add_model(
self,
name: str,
base_model_name: str,
model_type: str,
chunk_size: int,
chunk_overlap: int,
theme_name: str,
chunking_strategy: str,
similarity_metric: str,
) -> None:
"""Add a model run to the database.
Args:
name: Unique run name identifier.
base_model_name: The underlying model name.
model_type: Type of model (api, fastembed, etc.).
chunk_size: Chunk size used.
chunk_overlap: Chunk overlap used.
theme_name: Theme set name.
chunking_strategy: Chunking strategy used.
similarity_metric: Similarity metric used.
"""
with self.Session() as session:
stmt = sqlite_insert(Model).values(
name=name,
base_model_name=base_model_name,
type=model_type,
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
theme_name=theme_name,
chunking_strategy=chunking_strategy,
similarity_metric=similarity_metric,
).on_conflict_do_nothing()
session.execute(stmt)
session.commit()
[docs]
def add_evaluation_metrics(
self, model_name: str, metrics: Dict[str, float]
) -> None:
"""Add or update evaluation metrics for a model.
Args:
model_name: The model run name.
metrics: Dictionary of metric names to values.
"""
with self.Session() as session:
# Delete existing metrics for this model
session.execute(
delete(EvaluationMetric).where(EvaluationMetric.model_name == model_name)
)
# Insert new metrics
metric = EvaluationMetric(
model_name=model_name,
silhouette_score=metrics.get("silhouette_score"),
intra_cluster_distance_normalized=metrics.get("intra_cluster_distance_normalized"),
inter_cluster_distance_normalized=metrics.get("inter_cluster_distance_normalized"),
embedding_computation_time=metrics.get("embedding_computation_time"),
)
session.add(metric)
session.commit()
[docs]
def add_generated_file(
self, model_name: str, file_type: str, file_path: str
) -> None:
"""Add a generated file record to the database.
Args:
model_name: The model run name.
file_type: Type of the generated file.
file_path: Path to the generated file.
"""
with self.Session() as session:
file_entry = GeneratedFile(
model_name=model_name,
file_type=file_type,
file_path=file_path
)
session.add(file_entry)
session.commit()
[docs]
def add_global_chart(self, chart_type: str, file_path: str) -> None:
"""Add or update a global chart record.
Args:
chart_type: Type identifier for the chart.
file_path: Path to the chart image file.
"""
with self.Session() as session:
# Delete existing chart of this type
session.execute(
delete(GlobalChart).where(GlobalChart.chart_type == chart_type)
)
# Insert new chart
chart = GlobalChart(chart_type=chart_type, file_path=file_path)
session.add(chart)
session.commit()
[docs]
def model_exists(self, name: str) -> bool:
"""Check if a model with the specified run name exists.
Args:
name: The run name to check.
Returns:
True if the model exists, False otherwise.
"""
with self.Session() as session:
result = session.execute(
select(Model).where(Model.name == name)
).first()
return result is not None
[docs]
def save_processing_result(
self, model_name: str, file_id: str, results: Dict[str, Any]
) -> None:
"""Save detailed processing result for a file and model.
Args:
model_name: The model run name.
file_id: Identifier for the processed file.
results: Dictionary of processing results.
"""
# Apply intelligent quantization before serialization
quantized_results = self._apply_intelligent_quantization(results)
results_blob = msgpack.packb(
quantized_results, default=self._numpy_default, use_bin_type=True
)
with self.Session() as session:
stmt = sqlite_insert(ProcessingResult).values(
model_name=model_name,
file_id=file_id,
results_blob=results_blob
)
# Upsert: replace if exists
stmt = stmt.on_conflict_do_update(
index_elements=['model_name', 'file_id'],
set_=dict(results_blob=results_blob)
)
session.execute(stmt)
session.commit()
[docs]
def save_processing_results_batch(
self, results_batch: List[Tuple[str, str, Dict[str, Any]]]
) -> None:
"""Save a batch of processing results in a single transaction.
Args:
results_batch: List of (model_name, file_id, results) tuples.
"""
items_to_insert = []
for model_name, file_id, results in results_batch:
# Apply intelligent quantization before serialization
quantized_results = self._apply_intelligent_quantization(results)
results_blob = msgpack.packb(
quantized_results, default=self._numpy_default, use_bin_type=True
)
items_to_insert.append({
"model_name": model_name,
"file_id": file_id,
"results_blob": results_blob
})
if not items_to_insert:
return
with self.Session() as session:
stmt = sqlite_insert(ProcessingResult).values(items_to_insert)
stmt = stmt.on_conflict_do_update(
index_elements=['model_name', 'file_id'],
set_=dict(results_blob=stmt.excluded.results_blob)
)
session.execute(stmt)
session.commit()
[docs]
def get_processed_files(self, model_name: str) -> List[str]:
"""Retrieve file IDs that have been processed for a model.
Args:
model_name: The model run name.
Returns:
List of file IDs that have been processed.
"""
with self.Session() as session:
result = session.execute(
select(ProcessingResult.file_id).where(ProcessingResult.model_name == model_name)
)
return [row[0] for row in result.fetchall()]
[docs]
def get_model_info(self, run_name: str) -> Optional[Dict[str, Any]]:
"""Retrieve information about a model by its run name.
Args:
run_name: The unique run name identifier.
Returns:
Dictionary with model information, or None if not found.
"""
with self.Session() as session:
model = session.execute(
select(Model).where(Model.name == run_name)
).scalar_one_or_none()
if model:
return {
"model_name": model.name,
"type": model.type,
"chunk_size": model.chunk_size,
"chunk_overlap": model.chunk_overlap,
"theme_name": model.theme_name,
"chunking_strategy": model.chunking_strategy,
}
return None
[docs]
def get_all_processing_results(self) -> Dict[str, Any]:
"""Retrieve all processing results organized by model run name.
Fetches raw file-level results without model-level aggregation.
Returns:
Dictionary mapping model names to their file results.
"""
all_results: Dict[str, Dict[str, Any]] = {}
with self.Session() as session:
# Get all unique model names first
model_names = session.execute(
select(ProcessingResult.model_name).distinct()
).scalars().all()
for model_name in model_names:
# Get all file results for the current model
file_results = self.get_all_processing_results_for_run(model_name)
all_results[model_name] = {"files": file_results}
return all_results
[docs]
def get_all_models(self) -> List[Dict[str, Any]]:
"""Retrieve all models with their metrics.
Returns:
List of dictionaries containing model information and metrics.
"""
with self.Session() as session:
# Left join Model and EvaluationMetric
stmt = select(Model, EvaluationMetric).outerjoin(
EvaluationMetric, Model.name == EvaluationMetric.model_name
).order_by(Model.name)
results = session.execute(stmt).all()
models = []
for model, metric in results:
models.append(
{
"name": model.name,
"base_model_name": model.base_model_name,
"type": model.type,
"chunk_size": model.chunk_size,
"chunk_overlap": model.chunk_overlap,
"theme_name": model.theme_name,
"chunking_strategy": model.chunking_strategy,
"silhouette_score": metric.silhouette_score if metric else None,
"intra_cluster_distance_normalized": metric.intra_cluster_distance_normalized if metric else None,
"inter_cluster_distance_normalized": metric.inter_cluster_distance_normalized if metric else None,
"embedding_computation_time": metric.embedding_computation_time if metric else None,
}
)
return models
[docs]
def get_model_files(self, model_name: str) -> List[Dict[str, str]]:
"""Retrieve all generated files for a model.
Args:
model_name: The model run name.
Returns:
List of dictionaries with file type and path.
"""
with self.Session() as session:
files = session.execute(
select(GeneratedFile).where(GeneratedFile.model_name == model_name).order_by(GeneratedFile.file_type)
).scalars().all()
return [{"type": f.file_type, "path": f.file_path} for f in files]
[docs]
def get_global_charts(self) -> List[Dict[str, str]]:
"""Retrieve all global charts.
Returns:
List of dictionaries with chart type and path.
"""
with self.Session() as session:
charts = session.execute(
select(GlobalChart).order_by(GlobalChart.chart_type)
).scalars().all()
return [{"type": c.chart_type, "path": c.file_path} for c in charts]
[docs]
def vacuum_database(self) -> None:
"""Vacuum the database to reclaim space."""
with self.Session() as session:
session.execute(text("VACUUM"))
[docs]
def get_all_run_names(self) -> list[str]:
"""Retrieve all existing run names.
Returns:
List of unique run name identifiers.
"""
with self.Session() as session:
return session.execute(
select(Model.name).distinct()
).scalars().all()
[docs]
def get_processed_files_with_similarities(self, run_name: str) -> list[str]:
"""Retrieve files that have been processed with similarity scores.
Args:
run_name: The model run name.
Returns:
List of file IDs that have similarity data.
"""
processed_files = []
with self.Session() as session:
results = session.execute(
select(ProcessingResult.file_id, ProcessingResult.results_blob)
.where(ProcessingResult.model_name == run_name)
).all()
for file_id, results_blob in results:
results_data = msgpack.unpackb(
results_blob, object_hook=self._decode_numpy, raw=False
)
if "similarities" in results_data and results_data["similarities"] is not None:
processed_files.append(file_id)
return processed_files
[docs]
def get_embeddings_by_hashes(
self, base_model_name: str, text_hashes: List[str]
) -> Dict[str, np.ndarray]:
"""Retrieve embeddings from cache by model and text hashes.
Args:
base_model_name: The base model name used for embeddings.
text_hashes: List of text hash values to retrieve.
Returns:
Dictionary mapping text hashes to embedding arrays.
"""
if not text_hashes:
return {}
with self.Session() as session:
# SQLAlchemy IN clause
results = session.execute(
select(EmbeddingCache.text_hash, EmbeddingCache.vector)
.where(EmbeddingCache.model_name == base_model_name)
.where(EmbeddingCache.text_hash.in_(text_hashes))
).all()
embeddings = {}
for text_hash, vector_blob in results:
embeddings[text_hash] = msgpack.unpackb(
vector_blob, object_hook=self._decode_numpy, raw=False
)
return embeddings
[docs]
def save_embeddings_batch(
self, base_model_name: str, embeddings: Dict[str, np.ndarray]
) -> None:
"""Save a batch of embeddings to the cache.
Applies intelligent quantization to reduce storage size when enabled.
Args:
base_model_name: The base model name for the embeddings.
embeddings: Dictionary mapping text hashes to embedding arrays.
"""
if not embeddings:
return
# Flag for maximum compression in cache
self._cache_storage = True
items_to_insert = []
for text_hash, vector in embeddings.items():
# Apply intelligent quantization to embeddings for cache storage
if self.quantization_enabled and vector.dtype.kind == "f":
# Most embeddings are normalized [-1,1] → float16 is sufficient
if -1.1 <= vector.min() and vector.max() <= 1.1:
vector = vector.astype(np.float16)
elif vector.dtype == np.float64:
# Downcast float64 → float32
vector = vector.astype(np.float32)
vector_blob = msgpack.packb(
vector, default=self._numpy_default, use_bin_type=True
)
dimension = len(vector) if len(vector.shape) == 1 else vector.shape[1]
items_to_insert.append({
"model_name": base_model_name,
"text_hash": text_hash,
"vector": vector_blob,
"dimension": dimension
})
if hasattr(self, "_cache_storage"):
delattr(self, "_cache_storage")
with self.Session() as session:
stmt = sqlite_insert(EmbeddingCache).values(items_to_insert).on_conflict_do_nothing()
session.execute(stmt)
session.commit()
[docs]
def save_tsne_coordinates(
self, tsne_key: str, file_id: str, coordinates: Dict[str, List[float]]
) -> None:
"""Save t-SNE coordinates for a given configuration.
Args:
tsne_key: Unique key for the t-SNE configuration.
file_id: Identifier for the file.
coordinates: Dictionary with 'x' and 'y' coordinate lists.
"""
coordinates_blob = msgpack.packb(coordinates, use_bin_type=True)
with self.Session() as session:
stmt = sqlite_insert(TSNECoordinate).values(
tsne_key=tsne_key,
file_id=file_id,
coordinates=coordinates_blob
)
stmt = stmt.on_conflict_do_update(
index_elements=['tsne_key', 'file_id'],
set_=dict(coordinates=coordinates_blob)
)
session.execute(stmt)
session.commit()
[docs]
def get_tsne_coordinates(
self, tsne_key: str, file_id: str
) -> Optional[Dict[str, List[float]]]:
"""Retrieve t-SNE coordinates for a given configuration.
Args:
tsne_key: Unique key for the t-SNE configuration.
file_id: Identifier for the file.
Returns:
Dictionary with 'x' and 'y' coordinate lists, or None if not found.
"""
with self.Session() as session:
result = session.execute(
select(TSNECoordinate.coordinates)
.where(TSNECoordinate.tsne_key == tsne_key)
.where(TSNECoordinate.file_id == file_id)
).scalar_one_or_none()
if result:
return msgpack.unpackb(result, raw=False)
return None
[docs]
def clear_tsne_cache(self) -> None:
"""Clear all cached t-SNE coordinates."""
with self.Session() as session:
session.execute(delete(TSNECoordinate))
session.commit()
[docs]
def get_run_details(self, run_name: str) -> Optional[Dict[str, Any]]:
"""Retrieve detailed information for a specific run.
Args:
run_name: The unique run name identifier.
Returns:
Dictionary with full run details, or None if not found.
"""
with self.Session() as session:
model = session.execute(
select(Model).where(Model.name == run_name)
).scalar_one_or_none()
if model:
return {
"id": model.id,
"name": model.name,
"base_model_name": model.base_model_name,
"type": model.type,
"chunk_size": model.chunk_size,
"chunk_overlap": model.chunk_overlap,
"theme_name": model.theme_name,
"chunking_strategy": model.chunking_strategy,
"similarity_metric": model.similarity_metric,
"created_at": model.created_at
}
return None
[docs]
def get_all_processing_results_for_run(
self, model_name: str
) -> Dict[str, Dict[str, Any]]:
"""Get all processing results for a specific run with dequantization.
Args:
model_name: The model run name.
Returns:
Dictionary mapping file IDs to their processing results.
"""
results = {}
with self.Session() as session:
query_results = session.execute(
select(ProcessingResult.file_id, ProcessingResult.results_blob)
.where(ProcessingResult.model_name == model_name)
).all()
for file_id, results_blob in query_results:
data = msgpack.unpackb(
results_blob, object_hook=self._decode_numpy, raw=False
)
# Restore quantized data
restored_data = self._restore_quantized_data(data)
# Convert all numpy types to native Python types for JSON serialization
results[file_id] = self._to_native_python_types(restored_data)
return results
def _to_native_python_types(self, obj: Any) -> Any:
"""Recursively convert NumPy types to native Python types.
Args:
obj: Object potentially containing NumPy types.
Returns:
Object with all NumPy types converted to Python native types.
"""
if isinstance(obj, dict):
return {k: self._to_native_python_types(v) for k, v in obj.items()}
if isinstance(obj, list):
return [self._to_native_python_types(i) for i in obj]
if isinstance(obj, np.ndarray):
return obj.tolist()
# Use abstract base classes for broad compatibility (including NumPy 2.0)
if isinstance(obj, np.integer):
return int(obj)
if isinstance(obj, np.floating):
return float(obj)
if isinstance(obj, np.bool_):
return bool(obj)
return obj
def _dequantize_similarities(
self, similarities: np.ndarray | list[float]
) -> list[float]:
"""Dequantize similarities to [0,1] range.
Args:
similarities: Quantized or raw similarity values.
Returns:
List of dequantized similarity values in [0,1] range.
"""
if isinstance(similarities, list):
similarities = np.array(similarities)
# If values are uint16 (quantized), convert back to float [0,1]
if similarities.dtype == np.uint16:
similarities = similarities.astype(np.float32) / 65535.0
elif similarities.dtype in [np.int32, np.int64]:
# Handle case where uint16 was read as int
similarities = np.clip(similarities, 0, 65535).astype(np.float32) / 65535.0
# Ensure values are in [0,1] range
similarities = np.clip(similarities, 0.0, 1.0)
return similarities.tolist()
[docs]
def update_metrics_for_file(
self, model_name: str, file_id: str, metrics: Dict[str, Any]
) -> None:
"""Update metrics for a specific file in a model run.
Args:
model_name: The model run name.
file_id: Identifier for the file.
metrics: Dictionary of metric values to update.
"""
with self.Session() as session:
# First, retrieve the existing blob
result = session.execute(
select(ProcessingResult.results_blob)
.where(ProcessingResult.model_name == model_name)
.where(ProcessingResult.file_id == file_id)
).scalar_one_or_none()
if not result:
return # Or handle error
# Deserialize, update, and re-serialize
results_data = msgpack.unpackb(
result, object_hook=self._decode_numpy, raw=False
)
# Restore quantized data before update
results_data = self._restore_quantized_data(results_data)
results_data["metrics"].update(metrics)
# Apply quantization and serialize
quantized_results = self._apply_intelligent_quantization(results_data)
updated_blob = msgpack.packb(
quantized_results, default=self._numpy_default, use_bin_type=True
)
# Update the blob in the database
stmt = update(ProcessingResult).where(
ProcessingResult.model_name == model_name,
ProcessingResult.file_id == file_id
).values(results_blob=updated_blob)
session.execute(stmt)
session.commit()
    def _quantize_metrics(self, metrics: Dict[str, float]) -> Dict[str, Any]:
        """Quantize metrics based on their expected ranges.

        Applies intelligent quantization to reduce storage size while
        maintaining precision for the metric's expected range. Branching is
        driven by substrings in the metric name; _dequantize_metrics must
        mirror these rules exactly.

        Args:
            metrics: Dictionary of metric names to values.

        Returns:
            Dictionary with quantized metric values (uint16/float32/None).
        """
        if not self.quantization_enabled:
            return metrics
        quantized = {}
        for key, value in metrics.items():
            if value is None:
                # Preserve missing metrics as-is.
                quantized[key] = None
                continue
            # Similarity-based metrics [0,1] → uint16
            if any(
                keyword in key.lower()
                for keyword in ["similarity", "coherence", "density", "robustness"]
            ):
                if "coherence" in key.lower():
                    # Internal coherence: typically [0, 1], but invert for storage;
                    # dequantization restores it by computing 1 - x.
                    quantized[key] = np.uint16(
                        np.clip((1.0 - float(value)) * 65535, 0, 65535)
                    )
                else:
                    quantized[key] = np.uint16(np.clip(float(value) * 65535, 0, 65535))
            # Silhouette score [-1, 1] → uint16 with offset
            elif "silhouette" in key.lower():
                # Map [-1, 1] to [0, 65535]
                normalized = (float(value) + 1.0) / 2.0
                quantized[key] = np.uint16(np.clip(normalized * 65535, 0, 65535))
            # Distance metrics → float32
            elif "distance" in key.lower():
                if "normalized" in key.lower():
                    # Normalized distances [0,1] → uint16
                    quantized[key] = np.uint16(np.clip(float(value) * 65535, 0, 65535))
                else:
                    # Unbounded distances keep full float32 precision.
                    quantized[key] = np.float32(value)
            else:
                # Default: keep as float32
                quantized[key] = np.float32(value)
        return quantized
    def _dequantize_metrics(self, metrics: Dict[str, Any]) -> Dict[str, float | None]:
        """Dequantize metrics back to their original ranges.

        Reverses the quantization process of _quantize_metrics, using the
        same name-substring rules to decide how each value was encoded.

        Args:
            metrics: Dictionary of quantized metric values.

        Returns:
            Dictionary with dequantized float values (or None where missing).
        """
        if not self.quantization_enabled:
            return {k: float(v) if v is not None else None for k, v in metrics.items()}
        dequantized: Dict[str, float | None] = {}
        for key, value in metrics.items():
            if value is None:
                dequantized[key] = None
                continue
            # Handle quantized uint16 values.
            # NOTE(review): any integer <= 65535 (including negatives) is
            # assumed to be quantized — confirm callers never store genuine
            # small-integer metrics here.
            if (isinstance(value, (np.integer, int))) and value <= 65535:
                if "coherence" in key.lower():
                    # Internal coherence was inverted at quantization: restore original
                    dequantized[key] = 1.0 - (float(value) / 65535.0)
                elif "silhouette" in key.lower():
                    # Silhouette: map [0, 65535] back to [-1, 1]
                    normalized = float(value) / 65535.0
                    dequantized[key] = (normalized * 2.0) - 1.0
                elif any(
                    keyword in key.lower()
                    for keyword in ["similarity", "density", "robustness", "normalized"]
                ):
                    # Standard [0,1] metrics
                    dequantized[key] = float(value) / 65535.0
                else:
                    dequantized[key] = float(value)
            else:
                # Stored as float32 (or out of uint16 range): plain cast.
                dequantized[key] = float(value)
        return dequantized
[docs]
def get_db_modification_time(self) -> float:
"""Get the last modification time of the database file.
Returns:
Unix timestamp of the last modification.
"""
return os.path.getmtime(self.db_path)
    def _apply_intelligent_quantization(self, obj: Any) -> Any:
        """Apply intelligent quantization based on data type and content.

        Recursively walks dicts/lists and shrinks numeric payloads: known
        similarity keys become uint16 when in [0, 1], scatter coordinates
        become float16, metrics use _quantize_metrics, and large float
        arrays are downcast. _restore_quantized_data reverses this.

        Args:
            obj: Object to quantize.

        Returns:
            Quantized object with reduced precision where appropriate.
        """
        if not self.quantization_enabled:
            return obj
        if isinstance(obj, dict):
            quantized = {}
            for key, value in obj.items():
                if key in [
                    "similarities",
                    "cosine_similarity",
                    "dot_product",
                ] and isinstance(value, (np.ndarray, list)):
                    # Similarity metrics — only quantize when values lie in [0, 1]
                    if isinstance(value, list):
                        value = np.array(value)
                    if (
                        value.dtype.kind == "f"
                        and 0 <= value.min()
                        and value.max()
                        <= 1.01  # small tolerance for floating-point error
                    ):
                        quantized[key] = (value * 65535).astype(np.uint16)
                    else:
                        # Keep float32 for dot_product and other non-normalized metrics
                        quantized[key] = value.astype(np.float32)
                elif key == "scatter_plot_data" and isinstance(value, dict):
                    # 2D coordinates → float16
                    scatter_quantized = {}
                    for scatter_key, scatter_value in value.items():
                        if scatter_key in ["x", "y"] and isinstance(
                            scatter_value, (list, np.ndarray)
                        ):
                            scatter_quantized[scatter_key] = np.array(
                                scatter_value, dtype=np.float16
                            )
                        elif scatter_key == "similarities" and isinstance(
                            scatter_value, (list, np.ndarray)
                        ):
                            # Similarities in scatter plot data: uint16 when in [0, 1]
                            if isinstance(scatter_value, list):
                                scatter_value = np.array(scatter_value)
                            if (
                                scatter_value.dtype.kind == "f"
                                and 0 <= scatter_value.min()
                                and scatter_value.max() <= 1
                            ):
                                scatter_quantized[scatter_key] = (
                                    scatter_value * 65535
                                ).astype(np.uint16)
                            else:
                                scatter_quantized[scatter_key] = scatter_value
                        else:
                            # Any other scatter field passes through untouched.
                            scatter_quantized[scatter_key] = scatter_value
                    quantized[key] = scatter_quantized
                elif key == "metrics" and isinstance(value, dict):
                    # Apply metric-specific quantization
                    quantized[key] = self._quantize_metrics(value)
                else:
                    # Recurse into everything else.
                    quantized[key] = self._apply_intelligent_quantization(value)
            return quantized
        elif isinstance(obj, np.ndarray):
            if obj.dtype == np.float64:
                # Downcast float64 → float32
                return obj.astype(np.float32)
            elif obj.dtype.kind == "f" and obj.ndim > 1:
                # Embeddings: if normalized [-1,1], use float16
                # (out-of-range arrays fall through to the final return).
                if -1.1 <= obj.min() and obj.max() <= 1.1:
                    return obj.astype(np.float16)
        elif isinstance(obj, list):
            return [self._apply_intelligent_quantization(item) for item in obj]
        return obj
    def _restore_quantized_data(self, obj: Any) -> Any:
        """Restore quantized data to original format.

        Recursive inverse of _apply_intelligent_quantization: rescales
        uint16 similarity arrays back to float32 in [0, 1] and delegates
        metric dicts to _dequantize_metrics.

        Args:
            obj: Quantized object to restore.

        Returns:
            Object with restored full-precision values.
        """
        if not self.quantization_enabled:
            return obj
        if isinstance(obj, dict):
            restored = {}
            for key, value in obj.items():
                if key in [
                    "similarities",
                    "cosine_similarity",
                    "dot_product",
                ] and isinstance(value, np.ndarray):
                    # Restore uint16 → float32 only if the array was quantized
                    if value.dtype == np.uint16:
                        restored[key] = value.astype(np.float32) / 65535.0
                    else:
                        # Already stored as float32 (dot_product, etc.)
                        restored[key] = value.astype(np.float32)
                elif key == "scatter_plot_data" and isinstance(value, dict):
                    scatter_restored = {}
                    for scatter_key, scatter_value in value.items():
                        if scatter_key == "similarities" and isinstance(
                            scatter_value, np.ndarray
                        ):
                            # Same uint16 check as above for nested similarities.
                            if scatter_value.dtype == np.uint16:
                                scatter_restored[scatter_key] = (
                                    scatter_value.astype(np.float32) / 65535.0
                                )
                            else:
                                scatter_restored[scatter_key] = scatter_value
                        else:
                            # x/y float16 coordinates are left as stored.
                            scatter_restored[scatter_key] = scatter_value
                    restored[key] = scatter_restored
                elif key == "metrics" and isinstance(value, dict):
                    # Apply metric-specific dequantization
                    restored[key] = self._dequantize_metrics(value)
                else:
                    restored[key] = self._restore_quantized_data(value)
            return restored
        elif isinstance(obj, list):
            return [self._restore_quantized_data(item) for item in obj]
        return obj
def _numpy_default(self, obj: Any) -> Any:
"""Custom encoder for numpy data types for msgpack.
Handles numpy arrays by compressing them with zlib.
Args:
obj: Object to encode.
Returns:
Encoded representation suitable for msgpack.
Raises:
TypeError: If the object is not serializable.
"""
if isinstance(obj, np.integer):
return int(obj)
if isinstance(obj, np.floating):
return float(obj)
if isinstance(obj, np.ndarray):
# Use maximum compression for cache storage (embeddings don't change often)
compression_level = 9 if hasattr(self, "_cache_storage") else 6
return {
"__ndarray__": True,
"dtype": obj.dtype.str,
"shape": obj.shape,
"data": zlib.compress(obj.tobytes(), level=compression_level),
}
raise TypeError(f"Object of type {obj.__class__.__name__} is not serializable")
def _decode_numpy(self, obj: Any) -> Any:
"""Custom decoder for numpy data types for msgpack.
Decompresses and reconstructs numpy arrays from msgpack format.
Args:
obj: Object to decode.
Returns:
Decoded object, with numpy arrays reconstructed.
"""
if isinstance(obj, dict) and "__ndarray__" in obj:
data = zlib.decompress(obj["data"])
return np.frombuffer(data, dtype=np.dtype(obj["dtype"])).reshape(
obj["shape"]
)
return obj