evaluation refactoring for better data model
src/know_lang_bot/evaluation/chatbot_evaluation.py
CHANGED
@@ -1,6 +1,6 @@
 from typing import List, Dict, Optional
 from enum import Enum
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, computed_field
 from pydantic_ai import Agent
 from know_lang_bot.config import AppConfig
 from know_lang_bot.utils.model_provider import create_pydantic_model
@@ -20,25 +20,42 @@ class EvalCase(BaseModel):
     expected_code_refs: List[str] = Field(description="Code references that should be mentioned")
     difficulty: int = Field(ge=1, le=3, description="1: Easy, 2: Medium, 3: Hard")

-class EvalResult(BaseModel):
-    """Evaluation result with scores and feedback"""
-    evaluator_model: str
-    case: EvalCase
-    metrics: Dict[EvalMetric, float]
-    total_score: float
-    feedback: str
-    polished_question: Optional[str] = None

-class EvalAgentResponse(BaseModel):
+class MetricScores(BaseModel):
     chunk_relevance: float = Field(ge=0.0, le=10.0, description="Score for chunk relevance")
     answer_correctness: float = Field(ge=0.0, le=10.0, description="Score for answer correctness")
     code_reference: float = Field(ge=0.0, le=10.0, description="Score for code reference quality")
+
+    @computed_field
+    def weighted_total(self) -> float:
+        """Calculate weighted total score"""
+        weights = {
+            "chunk_relevance": 0.4,
+            "answer_correctness": 0.4,
+            "code_reference": 0.2
+        }
+        return sum(
+            getattr(self, metric) * weight
+            for metric, weight in weights.items()
+        )
+
+class EvalAgentResponse(MetricScores):
+    """Raw response from evaluation agent"""
     feedback: str

+class EvalResult(BaseModel):
+    """Evaluation result with scores and feedback"""
+    evaluator_model: str
+    case: EvalCase
+    eval_response: EvalAgentResponse

 class ChatBotEvaluationContext(EvalCase, ChatResult):
     pass

+class EvalSummary(EvalResult, ChatResult):
+    """Evaluation summary with chat and evaluation results"""
+    pass
+

 class ChatBotEvaluator:
     def __init__(self, config: AppConfig):
@@ -95,29 +112,10 @@ Format your response as JSON:
             eval_context.model_dump_json(),
         )
         eval_response : EvalAgentResponse = result.data
-        metrics = {
-            EvalMetric.CHUNK_RELEVANCE: eval_response.chunk_relevance,
-            EvalMetric.ANSWER_CORRECTNESS: eval_response.answer_correctness,
-            EvalMetric.CODE_REFERENCE: eval_response.code_reference
-        }
-
-        # Calculate weighted score
-        weights = {
-            EvalMetric.CHUNK_RELEVANCE: 0.4,
-            EvalMetric.ANSWER_CORRECTNESS: 0.4,
-            EvalMetric.CODE_REFERENCE: 0.2
-        }
-
-        total_score = sum(
-            metrics[metric] * weights[metric] * case.difficulty
-            for metric in EvalMetric
-        )

         return EvalResult(
             case=case,
-
-            total_score=total_score,
-            feedback=eval_response.feedback,
+            eval_response=eval_response,
             evaluator_model=f"{self.config.evaluator.model_provider}:{self.config.evaluator.model_name}"
         )

@@ -360,29 +358,29 @@ async def main():
     evaluator = ChatBotEvaluator(config)
     collection = chromadb.PersistentClient(path=str(config.db.persist_directory)).get_collection(name=config.db.collection_name)

-
+    summary_list : List[EvalSummary] = []

     for case in TRANSFORMER_TEST_CASES:
         try:
-            chat_result = await process_chat(question=case.question, collection=collection, config=config)
-            result = await evaluator.evaluate_single(case, chat_result)
+            chat_result : ChatResult = await process_chat(question=case.question, collection=collection, config=config)
+            result : EvalResult = await evaluator.evaluate_single(case, chat_result)

-
-
-
-
-
-
-            final_results.append(aggregated_result)
-
+            eval_summary = EvalSummary(
+                **chat_result.model_dump(),
+                **result.model_dump()
+            )
+            summary_list.append(eval_summary)
+
         except Exception:
             console.print_exception()

     # Write the final JSON array to a file
     with open("evaluation_results.json", "w") as f:
-
+        json_list = [summary.model_dump() for summary in summary_list]
+        json.dump(json_list, f, indent=2)
+

-    console.print(Pretty(
+    console.print(Pretty(summary_list))

 if __name__ == "__main__":
     asyncio.run(main())
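
Note on the refactored data model: because weighted_total on MetricScores is a pydantic @computed_field, the weighted score is derived on demand yet still appears in model_dump() and model_dump_json(), so EvalResult and EvalSummary no longer carry a stored total_score. The standalone sketch below is not part of the repository; the model and values are illustrative and assume pydantic v2. It only demonstrates that serialization behavior.

# Standalone sketch (assumes pydantic v2): a computed field is included when the model is serialized.
from pydantic import BaseModel, Field, computed_field

class MetricScores(BaseModel):
    chunk_relevance: float = Field(ge=0.0, le=10.0)
    answer_correctness: float = Field(ge=0.0, le=10.0)
    code_reference: float = Field(ge=0.0, le=10.0)

    @computed_field  # included in model_dump() / model_dump_json()
    @property
    def weighted_total(self) -> float:
        # 0.4 * chunk_relevance + 0.4 * answer_correctness + 0.2 * code_reference
        weights = {"chunk_relevance": 0.4, "answer_correctness": 0.4, "code_reference": 0.2}
        return sum(getattr(self, name) * weight for name, weight in weights.items())

scores = MetricScores(chunk_relevance=8.0, answer_correctness=9.0, code_reference=6.0)
print(scores.model_dump())  # the dict contains 'weighted_total' alongside the three raw scores

Two implications worth noting. First, EvalSummary(**chat_result.model_dump(), **result.model_dump()) assumes ChatResult and EvalResult share no field names; a duplicate key across the two dumps would raise a TypeError at the constructor call. Second, the old total_score scaled the weighted sum by case.difficulty, while weighted_total applies only the fixed metric weights; difficulty remains available on the nested EvalCase if downstream analysis still wants to weight by it.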