File size: 6,689 Bytes
e5bfc68 60532a1 e5bfc68 60532a1 e5bfc68 c9b82b3 e5bfc68 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 |
from pathlib import Path
import matplotlib.pyplot as plt
import seaborn as sns
from typing import List, Dict, Tuple, Optional
import chromadb
import numpy as np
from pydantic import BaseModel
from dataclasses import dataclass
from knowlang.chat_bot.chat_graph import ChatResult
from knowlang.configs.config import AppConfig, EmbeddingConfig
import json
from knowlang.evaluation.chatbot_evaluation import EvalCase, TRANSFORMER_TEST_CASES
from knowlang.models.embeddings import EmbeddingInputType, generate_embedding, EmbeddingVector
@dataclass()
class ConfigEvalResult:
    """Outcome of the distance analysis for one named configuration.

    Instances are built by ``analyze_embedding_distributions`` and consumed by
    ``plot_distance_distributions``.
    """

    # Display label of the configuration (used as the plot's 'method' value).
    config_name: str
    # All distances returned by the vector-store queries for this configuration.
    distances: List[float]
    # Summary statistics of ``distances`` keyed by name ('mean', 'std', ...).
    stats: Dict[str, float]
class RetrievalMetrics(BaseModel):
    """Metrics for retrieval analysis, attachable to an ``EnhancedChatResult``.

    NOTE(review): nothing in the code visible here populates this model, so the
    field meanings below are inferred from their names — confirm with the
    producer before relying on them.
    """
    # presumably one vector-store distance per retrieved chunk — TODO confirm
    distances: List[float]
    # presumably a similarity score per retrieved chunk (derived from distances?) — TODO confirm
    similarity_scores: List[float]
    # presumably the number of chunks retrieved — TODO confirm
    chunk_count: int
class EnhancedChatResult(ChatResult):
    """``ChatResult`` extended with optional retrieval metrics.

    Inherits every field of ``ChatResult`` and adds ``retrieval_metrics``.
    """
    # Defaults to None, so the result can be constructed without any metrics.
    retrieval_metrics: Optional[RetrievalMetrics] = None
def _embedding_cache_path(embdding_config: EmbeddingConfig) -> Path:
"""Get the cache path for a specific embedding configuration"""
return Path(f"embeddings_{embdding_config.model_name}_{embdding_config.model_provider.value}.json")
def load_cached_embeddings(cache_path: Path, embdding_config: EmbeddingConfig) -> Optional[Dict[str, EmbeddingVector]]:
    """Load cached embeddings for a specific embedding configuration.

    Args:
        cache_path: Directory that holds the embedding cache files.
        embdding_config: Embedding configuration whose cache file is read.

    Returns:
        Mapping of question text to embedding vector. Returns ``None`` when
        there is no usable cache: the file is missing, is not valid JSON, or
        does not contain a JSON object. ``None`` tells the caller to regenerate.
    """
    cache_file = cache_path / _embedding_cache_path(embdding_config)
    try:
        with open(cache_file, 'r', encoding='utf-8') as f:
            data = json.load(f)
    except FileNotFoundError:
        return None
    except (json.JSONDecodeError, UnicodeDecodeError) as e:
        # A truncated/corrupt file (e.g. from an interrupted write) should cause
        # regeneration rather than abort the whole evaluation run.
        print(f"Ignoring corrupt embedding cache {cache_file}: {e}")
        return None
    if not isinstance(data, dict):
        print(f"Ignoring malformed embedding cache {cache_file}: expected a JSON object")
        return None
    return data
def save_cached_embeddings(cache_path: Path, config: EmbeddingConfig, embeddings: Dict[str, EmbeddingVector]):
    """Save embeddings to the cache file for ``config``, replacing any old one.

    The data is first written to a sibling ``.tmp`` file and then moved over
    the real cache file. An interrupted or failed dump therefore never leaves
    a half-written cache behind.

    Args:
        cache_path: Directory that holds the embedding cache files (created if needed).
        config: Embedding configuration that determines the cache file name.
        embeddings: Mapping of question text to embedding vector. It must be
            JSON-serialisable.

    Raises:
        Any error from serialisation or file I/O, after removing the temp file.
    """
    cache_file = cache_path / _embedding_cache_path(config)
    # Create the file's actual parent, which is cache_path itself for flat names.
    cache_file.parent.mkdir(parents=True, exist_ok=True)
    tmp_file = cache_file.with_name(cache_file.name + ".tmp")
    try:
        with open(tmp_file, 'w', encoding='utf-8') as f:
            json.dump(embeddings, f)
        # Path.replace overwrites the destination in a single rename.
        tmp_file.replace(cache_file)
    except BaseException:
        # Clean up the partial temp file, then propagate the original error.
        tmp_file.unlink(missing_ok=True)
        raise
def _load_or_generate_embeddings(
    config_name: str,
    config: AppConfig,
    questions: List[str],
    cache_path: Path,
) -> Optional[Dict[str, EmbeddingVector]]:
    """Return an embedding for every question, reusing the on-disk cache when possible.

    Questions missing from the cache are embedded, merged into it and saved
    back. A cache written for an older set of test cases is therefore extended
    instead of causing a ``KeyError``. Returns ``None`` if generating or
    saving fails.
    """
    # dict.fromkeys de-duplicates while keeping first-seen order.
    unique_questions = list(dict.fromkeys(questions))
    cached = load_cached_embeddings(cache_path, config.embedding)
    if cached is None:
        print(f"Generating new embeddings for {config_name}...")
        cached = {}
        missing = unique_questions
    else:
        missing = [q for q in unique_questions if q not in cached]
        if not missing:
            print(f"Using cached embeddings for {config_name}")
            return cached
        print(f"Cache for {config_name} lacks {len(missing)} question(s); generating them...")

    if not missing:
        # No questions at all: nothing to embed or cache.
        return cached

    try:
        new_vectors = list(generate_embedding(missing, config.embedding, input_type=EmbeddingInputType.QUERY))
        if len(new_vectors) != len(missing):
            # zip() would silently drop the unmatched questions; fail loudly instead.
            raise ValueError(f"expected {len(missing)} embeddings, got {len(new_vectors)}")
        cached.update(zip(missing, new_vectors))
        save_cached_embeddings(cache_path, config.embedding, cached)
    except Exception as e:
        # Best effort per configuration: report and let the caller skip it.
        print(f"Error generating embeddings for {config_name}: {str(e)}")
        return None
    return cached


def _query_distances(collection, embeddings: Dict[str, EmbeddingVector], questions: List[str], n_results: int) -> List[float]:
    """Query ``collection`` once per question and concatenate the top-``n_results`` distances."""
    distances: List[float] = []
    for question in questions:
        query_results = collection.query(
            query_embeddings=[embeddings[question]],
            n_results=n_results,
            include=['distances'],
        )
        # One query embedding was sent, so take the first (only) result list.
        distances.extend(query_results['distances'][0])
    return distances


def _distance_stats(distances: List[float]) -> Dict[str, float]:
    """Summary statistics (population std, ddof=0) of a non-empty list of distances."""
    values = np.asarray(distances, dtype=float)
    return {
        'mean': float(np.mean(values)),
        'std': float(np.std(values)),
        'median': float(np.median(values)),
        'min': float(np.min(values)),
        'max': float(np.max(values)),
    }


async def analyze_embedding_distributions(
    test_cases: List[EvalCase],
    configs: List[Tuple[str, AppConfig]],
    cache_path: Path,
    n_results: int = 10,
) -> List[ConfigEvalResult]:
    """Analyze embedding distance distributions for multiple configurations.

    For each ``(name, config)`` pair, the function:

    * embeds every test-case question, using the cache under ``cache_path``
      and filling any gaps;
    * queries the configured Chroma collection for the ``n_results`` nearest
      chunks per question;
    * summarises all returned distances.

    Args:
        test_cases: Evaluation cases. Only ``case.question`` is used.
        configs: ``(display name, AppConfig)`` pairs to compare.
        cache_path: Directory for the per-model embedding cache files.
        n_results: Number of nearest neighbours requested per question.

    Returns:
        One ``ConfigEvalResult`` per configuration that produced distances.
        Configurations whose embeddings could not be generated, or whose
        queries returned no distances, are reported and skipped.
    """
    results: List[ConfigEvalResult] = []
    questions = [case.question for case in test_cases]
    for config_name, config in configs:
        embeddings = _load_or_generate_embeddings(config_name, config, questions, cache_path)
        if embeddings is None:
            continue

        collection = chromadb.PersistentClient(
            path=str(config.db.persist_directory)
        ).get_collection(str(config.db.collection_name))

        distances = _query_distances(collection, embeddings, questions, n_results)
        if not distances:
            # np.min/np.max raise on empty input, so skip rather than crash.
            print(f"No distances retrieved for {config_name}; skipping")
            continue

        results.append(ConfigEvalResult(
            config_name=config_name,
            distances=distances,
            stats=_distance_stats(distances),
        ))
    return results
def plot_distance_distributions(
    results: List[ConfigEvalResult],
    output_path: str = 'embedding_comparison.png',
):
    """Plot and compare distance distributions for all configurations.

    The figure has three panels:

    * a density histogram of the distances, coloured by configuration;
    * a box plot per configuration;
    * a text panel with each configuration's summary statistics.

    The statistics text is also printed, and the figure is saved to
    ``output_path``.

    Args:
        results: Per-configuration results, as produced by
            ``analyze_embedding_distributions``.
        output_path: Destination passed to ``plt.savefig``. It defaults to
            ``'embedding_comparison.png'`` in the working directory.
    """
    import pandas as pd

    # seaborn cannot plot an empty long-form frame (there is no 'distance' column).
    if not any(result.distances for result in results):
        print("No distances to plot; skipping embedding comparison figure")
        return

    # Long-form frame: one row per (configuration, distance) pair.
    all_data = pd.DataFrame([
        {'method': result.config_name, 'distance': d}
        for result in results
        for d in result.distances
    ])

    fig = plt.figure(figsize=(15, 8))
    try:
        # NOTE(review): the axes are labelled "Cosine Distance", but the values are
        # in whatever space the Chroma collection was created with (Chroma's
        # default is squared L2) — confirm the collection uses cosine.
        plt.subplot(2, 2, 1)
        # NOTE: common_norm=True normalises across all methods together, so each
        # method's area equals its share of the observations. Use False to make
        # every method's density integrate to 1 on its own.
        sns.histplot(
            data=all_data,
            x='distance',
            hue='method',
            stat='density',
            common_norm=True,
            alpha=0.6
        )
        plt.title('Distance Distribution Comparison')
        plt.xlabel('Cosine Distance')
        plt.ylabel('Density')

        plt.subplot(2, 2, 2)
        sns.boxplot(
            data=all_data,
            x='method',
            y='distance'
        )
        plt.title('Distance Distribution Statistics')
        plt.ylabel('Cosine Distance')
        plt.xticks(rotation=45)

        # The summary panel spans the whole bottom row (grid cells 3 and 4).
        plt.subplot(2, 2, (3, 4))
        sections = []
        for result in results:
            stats = result.stats
            sections.append(
                f"{result.config_name}:\n"
                f" Mean: {stats['mean']:.3f} ± {stats['std']:.3f}\n"
                f" Median: {stats['median']:.3f}\n"
                f" Range: [{stats['min']:.3f}, {stats['max']:.3f}]\n\n"
            )
        stats_text = "Distance Statistics:\n\n" + "".join(sections)
        plt.text(0.1, 0.9, stats_text,
                 transform=plt.gca().transAxes,
                 verticalalignment='top',
                 fontfamily='monospace')
        plt.axis('off')
        print(stats_text)

        plt.tight_layout()
        plt.savefig(output_path)
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(fig)
async def main():
    """Run the embedding comparison for the evaluation configurations.

    Each display name maps to the env file from which its ``AppConfig`` is
    built. The distance distributions are analysed, then plotted, and the
    per-configuration results are returned.
    """
    env_file_by_label = {
        "Ollama Embedding": '.env.evaluation.ollama',
        "OpenAI Embedding": '.env.evaluation.openai',
        # Add more configurations as needed
    }
    configs = [
        (label, AppConfig(_env_file=Path(env_file)))
        for label, env_file in env_file_by_label.items()
    ]

    embedding_cache_dir = Path("evaluations", "embedding_cache")
    results = await analyze_embedding_distributions(
        TRANSFORMER_TEST_CASES,
        configs,
        cache_path=embedding_cache_dir,
    )

    plot_distance_distributions(results)
    return results
if __name__ == "__main__":
    # Script entry point: drive the async pipeline on a fresh event loop.
    from asyncio import run as run_async
    run_async(main())