Claude refactoring the evaluation visualization
src/know_lang_bot/evaluation/chatbot_evaluation_visualize.py (changed)
Old side of the diff (removed lines are marked "-"; text truncated in the extracted page is shown as "…"):

@@ -3,21 +3,33 @@ import json
 import pandas as pd
 from rich.console import Console
 from rich.table import Table
-from typing import List
 from know_lang_bot.evaluation.chatbot_evaluation import EvalSummary

 class ResultAnalyzer:
-    def __init__(self, base_dir: Path):
         self.console = Console()
-        self.…
-        self.…
-
     def load_results(self, file_path: Path) -> List[EvalSummary]:
         """Load evaluation results from JSON file"""
         with open(file_path) as f:
             obj_list = json.load(f)
             return [EvalSummary.model_validate(obj) for obj in obj_list]
-
     def create_dataframe(self, results: List[EvalSummary]) -> pd.DataFrame:
         """Convert results to pandas DataFrame with flattened metrics"""
         rows = []

@@ -29,116 +41,154 @@ class ResultAnalyzer:
                 "chunk_relevance": result.eval_response.chunk_relevance,
                 "answer_correctness": result.eval_response.answer_correctness,
                 "code_reference": result.eval_response.code_reference,
-                "weighted_total": result.eval_response.weighted_total
             }
             rows.append(row)

         return pd.DataFrame(rows)

-    def …
-        """…
-        …
-            all_results["reranking"].extend(self.load_results(file))
-
-        # Convert to DataFrames
-        embedding_df = self.create_dataframe(all_results["embedding"])
-        reranking_df = self.create_dataframe(all_results["reranking"])
-
-        # Calculate statistics by evaluator model
-        def get_model_stats(df: pd.DataFrame) -> pd.DataFrame:
-            return df.groupby("evaluator_model").agg({
-                "chunk_relevance": ["mean", "std"],
-                "answer_correctness": ["mean", "std"],
-                "code_reference": ["mean", "std"],
-                "weighted_total": ["mean", "std"]
-            }).round(2)
-
-        embedding_stats = get_model_stats(embedding_df)
-        reranking_stats = get_model_stats(reranking_df)
-
-        # Display comparison tables
-        self.display_comparison_table(embedding_stats, reranking_stats)
-        self.display_improvement_metrics(embedding_df, reranking_df)
-
-        # Save detailed results to CSV
-        self.save_detailed_results(embedding_df, reranking_df)

-    def …
-        """…
-        …

         table.add_column("Metric", style="cyan")
         table.add_column("Model", style="magenta")
-        …

         metrics = ["chunk_relevance", "answer_correctness", "code_reference", "weighted_total"]
-
         for metric in metrics:
-            for model in embedding_stats.index:
-                emb_mean = embedding_stats.loc[model, (metric, "mean")]
-                emb_std = embedding_stats.loc[model, (metric, "std")]
-                rer_mean = reranking_stats.loc[model, (metric, "mean")]
-                rer_std = reranking_stats.loc[model, (metric, "std")]
-
-                improvement = ((rer_mean - emb_mean) / emb_mean * 100).round(1)
-
-                table.add_row(
                     metric.replace("_", " ").title(),
-                    model.split(":")[-1]
-                    …
-                )

         self.console.print(table)

-    def …
-        """Display …
-        …

-        …

-            table.add_row(
-                str(diff),
-                f"{emb_score:.2f}",
-                f"{rer_score:.2f}",
-                f"{improvement:+.1f}%"
-            )

         self.console.print(table)

-    def …
         """Save detailed results to CSV"""
-        # …
-        …

-        …
-        combined_df.to_csv(…
-        self.console.print(f"\nDetailed results saved to …

 if __name__ == "__main__":
-    analyzer = ResultAnalyzer(…
     analyzer.analyze_results()
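The refactor replaces this hard-coded two-way embedding-versus-reranking comparison with a RetrievalMethod enum and a configurable baseline method, so any number of retrieval variants (including the new embedding_reranking_with_code run) flow through a single load/compare/export path.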
New side of the diff (added lines are marked "+"):

@@ -3,21 +3,33 @@ import json
 import pandas as pd
 from rich.console import Console
 from rich.table import Table
+from typing import List, Dict
+from enum import Enum
 from know_lang_bot.evaluation.chatbot_evaluation import EvalSummary

+class RetrievalMethod(str, Enum):
+    EMBEDDING = "embedding"
+    EMBEDDING_RERANKING = "embedding_reranking"
+    EMBEDDING_WITH_CODE = "embedding_reranking_with_code"
+
 class ResultAnalyzer:
+    def __init__(self, base_dir: Path, baseline_method: RetrievalMethod = RetrievalMethod.EMBEDDING):
         self.console = Console()
+        self.base_dir = base_dir
+        self.baseline_method = baseline_method
+        # Map each method to its directory
+        self.method_dirs = {
+            RetrievalMethod.EMBEDDING: self.base_dir / RetrievalMethod.EMBEDDING.value,
+            RetrievalMethod.EMBEDDING_RERANKING: self.base_dir / RetrievalMethod.EMBEDDING_RERANKING.value,
+            RetrievalMethod.EMBEDDING_WITH_CODE: self.base_dir / RetrievalMethod.EMBEDDING_WITH_CODE.value
+        }
+
     def load_results(self, file_path: Path) -> List[EvalSummary]:
         """Load evaluation results from JSON file"""
         with open(file_path) as f:
             obj_list = json.load(f)
             return [EvalSummary.model_validate(obj) for obj in obj_list]
+
     def create_dataframe(self, results: List[EvalSummary]) -> pd.DataFrame:
         """Convert results to pandas DataFrame with flattened metrics"""
         rows = []

@@ -29,116 +41,154 @@ class ResultAnalyzer:
                 "chunk_relevance": result.eval_response.chunk_relevance,
                 "answer_correctness": result.eval_response.answer_correctness,
                 "code_reference": result.eval_response.code_reference,
+                "weighted_total": result.eval_response.weighted_total,
+                "environment": getattr(result.case, 'environment', 'default')  # Added environment
             }
             rows.append(row)

         return pd.DataFrame(rows)

+    def load_all_results(self) -> Dict[RetrievalMethod, pd.DataFrame]:
+        """Load results for all available methods"""
+        results = {}
+        for method in RetrievalMethod:
+            method_dir = self.method_dirs.get(method)
+            if method_dir and method_dir.exists():
+                all_results = []
+                for file in method_dir.glob("*.json"):
+                    all_results.extend(self.load_results(file))
+                if all_results:  # Only include methods with results
+                    results[method] = self.create_dataframe(all_results)
+        return results

+    def calculate_improvement(self, new_val: float, baseline_val: float) -> str:
+        """Calculate and format improvement percentage"""
+        if baseline_val == 0:
+            return "N/A"
+        improvement = ((new_val - baseline_val) / baseline_val * 100).round(1)
+        return f"{improvement:+.1f}%" if improvement else "0%"
+
+    def get_stats_by_group(self, df: pd.DataFrame, group_by: str) -> pd.DataFrame:
+        """Calculate statistics grouped by specified column"""
+        return df.groupby(group_by).agg({
+            "chunk_relevance": ["mean", "std"],
+            "answer_correctness": ["mean", "std"],
+            "code_reference": ["mean", "std"],
+            "weighted_total": ["mean", "std"]
+        }).round(2)
+
+    def display_comparison_table(self, results: Dict[RetrievalMethod, pd.DataFrame]):
+        """Display rich table comparing all methods"""
+        table = Table(title="Method Comparison by Evaluator Model")

+        # Add columns
         table.add_column("Metric", style="cyan")
         table.add_column("Model", style="magenta")
+        for method in results.keys():
+            table.add_column(method.value.replace("_", " ").title(), style="blue")
+            if method != self.baseline_method:
+                table.add_column(f"{method.value} Improvement", style="yellow")
+
+        # Calculate stats for each method
+        stats_by_method = {
+            method: self.get_stats_by_group(df, "evaluator_model")
+            for method, df in results.items()
+        }

         metrics = ["chunk_relevance", "answer_correctness", "code_reference", "weighted_total"]
+        baseline_stats = stats_by_method[self.baseline_method]
+
         for metric in metrics:
+            for model in baseline_stats.index:
+                row_data = [
                     metric.replace("_", " ").title(),
+                    model.split(":")[-1]
+                ]
+
+                # Add data for each method
+                for method in results.keys():
+                    stats = stats_by_method[method]
+                    mean = stats.loc[model, (metric, "mean")]
+                    std = stats.loc[model, (metric, "std")]
+                    row_data.append(f"{mean:.2f} ±{std:.2f}")
+
+                    # Add improvement column if not baseline
+                    if method != self.baseline_method:
+                        baseline_mean = baseline_stats.loc[model, (metric, "mean")]
+                        row_data.append(self.calculate_improvement(mean, baseline_mean))
+
+                table.add_row(*row_data)

         self.console.print(table)

+    def display_environment_comparison(self, results: Dict[RetrievalMethod, pd.DataFrame]):
+        """Display comparison across different evaluation environments"""
+        table = Table(title="Method Comparison by Environment")
+
+        table.add_column("Environment", style="cyan")
+        for method in results.keys():
+            table.add_column(method.value.replace("_", " ").title(), style="blue")
+            if method != self.baseline_method:
+                table.add_column(f"{method.value} Improvement", style="yellow")
+
+        # Get environments from all results
+        environments = sorted(set().union(*[
+            set(df["environment"].unique())
+            for df in results.values()
+        ]))
+
+        baseline_df = results[self.baseline_method]

+        for env in environments:
+            row_data = [env]
+
+            for method in results.keys():
+                df = results[method]
+                env_score = df[df["environment"] == env]["weighted_total"].mean()
+                row_data.append(f"{env_score:.2f}")
+
+                if method != self.baseline_method:
+                    baseline_score = baseline_df[
+                        baseline_df["environment"] == env
+                    ]["weighted_total"].mean()
+                    row_data.append(self.calculate_improvement(env_score, baseline_score))

+            table.add_row(*row_data)

         self.console.print(table)

+    def analyze_results(self):
+        """Analyze and display results comparison"""
+        results = self.load_all_results()
+        if not results:
+            self.console.print("[red]No results found!")
+            return
+
+        # Display comparisons
+        self.display_comparison_table(results)
+        self.console.print("\n")
+        self.display_environment_comparison(results)
+
+        # Save detailed results
+        self.save_detailed_results(results)
+
+    def save_detailed_results(self, results: Dict[RetrievalMethod, pd.DataFrame]):
         """Save detailed results to CSV"""
+        # Combine all results with method column
+        dfs = []
+        for method, df in results.items():
+            df = df.copy()
+            df["method"] = method.value
+            dfs.append(df)

+        combined_df = pd.concat(dfs)
+        output_path = self.base_dir / "evaluation_comparison.csv"
+        combined_df.to_csv(output_path, index=False)
+        self.console.print(f"\nDetailed results saved to {output_path}")

 if __name__ == "__main__":
+    analyzer = ResultAnalyzer(
+        Path("evaluations"),
+        baseline_method=RetrievalMethod.EMBEDDING
+    )
     analyzer.analyze_results()
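For reference, load_all_results implies one subdirectory per enum value under the base directory handed to ResultAnalyzer, with each JSON file holding a list of serialized EvalSummary objects. A sketch of the assumed layout (the file names are arbitrary):

    evaluations/
    ├── embedding/                        # RetrievalMethod.EMBEDDING
    │   └── *.json
    ├── embedding_reranking/              # RetrievalMethod.EMBEDDING_RERANKING
    │   └── *.json
    └── embedding_reranking_with_code/    # RetrievalMethod.EMBEDDING_WITH_CODE
        └── *.json

Methods whose directory is missing or empty are skipped, so partial comparisons still render.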
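The per-model lookups in display_comparison_table rely on the MultiIndex columns that groupby(...).agg(...) produces. A minimal, self-contained sketch of that access pattern, with invented model names and scores:

    import pandas as pd

    df = pd.DataFrame({
        "evaluator_model": ["openai:gpt-4o", "openai:gpt-4o",
                            "anthropic:sonnet", "anthropic:sonnet"],
        "weighted_total": [7.0, 8.0, 6.0, 9.0],
    })

    # Aggregating with a list of functions yields MultiIndex columns,
    # e.g. ("weighted_total", "mean") and ("weighted_total", "std").
    stats = df.groupby("evaluator_model").agg({
        "weighted_total": ["mean", "std"]
    }).round(2)

    for model in stats.index:
        mean = stats.loc[model, ("weighted_total", "mean")]
        std = stats.loc[model, ("weighted_total", "std")]
        print(f"{model.split(':')[-1]}: {mean:.2f} ±{std:.2f}")

calculate_improvement then reduces a method's mean against the baseline mean to a percentage, ((new - baseline) / baseline) * 100, so a baseline mean of 6.00 compared with 7.50 renders as +25.0%.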