transformers evaluation test cases draft
src/know_lang_bot/evaluation/chatbot_evaluation.py
ADDED
@@ -0,0 +1,336 @@
from typing import List, Dict, Optional
from enum import Enum
from pydantic import BaseModel, Field
from pydantic_ai import Agent
from know_lang_bot.config import AppConfig
from know_lang_bot.utils.model_provider import create_pydantic_model
from know_lang_bot.chat_bot.chat_graph import ChatResult
import asyncio


class EvalMetric(str, Enum):
    CHUNK_RELEVANCE = "chunk_relevance"
    ANSWER_CORRECTNESS = "answer_correctness"
    CODE_REFERENCE = "code_reference"


class EvalCase(BaseModel):
    """Single evaluation case focused on code understanding"""
    question: str
    expected_files: List[str] = Field(description="Files that should be in retrieved chunks")
    expected_concepts: List[str] = Field(description="Key concepts that should be in the answer")
    expected_code_refs: List[str] = Field(description="Code references that should be mentioned")
    difficulty: int = Field(ge=1, le=3, description="1: Easy, 2: Medium, 3: Hard")


class EvalResult(BaseModel):
    """Evaluation result with scores and feedback"""
    case: EvalCase
    metrics: Dict[EvalMetric, float]
    total_score: float
    feedback: str
    polished_question: Optional[str] = None


class EvalAgentResponse(BaseModel):
    """Structured scores and feedback returned by the evaluation agent"""
    chunk_relevance: float
    answer_correctness: float
    code_reference: float
    feedback: str


class ChatBotEvaluationContext(EvalCase, ChatResult):
    """Combined view of an evaluation case and the chat result under review"""


class ChatBotEvaluator:
    def __init__(self, config: AppConfig):
        """Initialize evaluator with app config"""
        self.config = config
        self.eval_agent = Agent(
            create_pydantic_model(
                model_provider=config.llm.model_provider,
                model_name=config.llm.model_name
            ),
            system_prompt=self._build_eval_prompt(),
            result_type=EvalAgentResponse
        )

    def _build_eval_prompt(self) -> str:
        return """You are an expert evaluator of code understanding systems.
Evaluate the response based on these specific criteria:

1. Chunk Relevance (0-1):
   - Are the retrieved code chunks from the expected files?
   - Do they contain relevant code sections?

2. Answer Correctness (0-1):
   - Does the answer accurately explain the code?
   - Are the expected concepts covered?

3. Code Reference Quality (0-1):
   - Does it properly cite specific code locations?
   - Are code references clear and relevant?"""
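
    # The agent's reply is parsed into an EvalAgentResponse, i.e. something
    # like (illustrative values only, not from a real run):
    #   {"chunk_relevance": 0.9, "answer_correctness": 0.8,
    #    "code_reference": 0.7, "feedback": "Covers the base class well..."}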

    async def evaluate_single(
        self,
        case: EvalCase,
        chat_result: ChatResult
    ) -> EvalResult:
        """Evaluate a single case"""
        # Prepare evaluation context by merging the case with the chat result
        eval_context = ChatBotEvaluationContext(
            **case.model_dump(),
            **chat_result.model_dump()
        )

        # Get evaluation from the model (a pydantic model is not directly
        # JSON-serializable, so serialize it explicitly)
        result = await self.eval_agent.run(
            eval_context.model_dump_json(),
        )
        eval_response = result.data

        # Calculate weighted score, scaled by case difficulty
        weights = {
            EvalMetric.CHUNK_RELEVANCE: 0.4,
            EvalMetric.ANSWER_CORRECTNESS: 0.4,
            EvalMetric.CODE_REFERENCE: 0.2,
        }
        metrics = {
            metric: getattr(eval_response, metric.value)
            for metric in EvalMetric
        }
        total_score = sum(
            metrics[metric] * weights[metric] * case.difficulty
            for metric in EvalMetric
        )

        return EvalResult(
            case=case,
            metrics=metrics,
            total_score=total_score,
            feedback=eval_response.feedback
        )
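
    # Worked example of the weighting (illustrative numbers, not from a real
    # run): a difficulty-3 case scored chunk_relevance=0.9,
    # answer_correctness=0.8, code_reference=0.7 totals
    # 3 * (0.9*0.4 + 0.8*0.4 + 0.7*0.2) = 2.46, and a perfect difficulty-3
    # case tops out at 3.0.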

    async def evaluate_batch(
        self,
        cases: List[EvalCase],
        process_chat_func,
        max_concurrent: int = 2
    ) -> List[EvalResult]:
        """Run evaluation on multiple cases with concurrency control"""
        semaphore = asyncio.Semaphore(max_concurrent)

        async def eval_single_with_limit(case: EvalCase) -> EvalResult:
            async with semaphore:
                chat_result = await process_chat_func(case.question)
                return await self.evaluate_single(case, chat_result)

        return await asyncio.gather(
            *[eval_single_with_limit(case) for case in cases]
        )


# src/transformers/quantizers/base.py
TRANSFORMER_QUANTIZER_BASE_CASES = [
    EvalCase(
        question="How are different quantization methods implemented in the transformers library, and what are the key components required to implement a new quantization method?",
        expected_files=["quantizers/base.py"],
        expected_concepts=[
            "HfQuantizer abstract base class",
            "PreTrainedModel quantization",
            "pre/post processing of models",
            "quantization configuration",
            "requires_calibration flag"
        ],
        expected_code_refs=[
            "class HfQuantizer",
            "preprocess_model method",
            "postprocess_model method",
            "_process_model_before_weight_loading",
            "requires_calibration attribute"
        ],
        difficulty=3
    )
]

# src/transformers/quantizers/auto.py
TRANSFORMER_QUANTIZER_AUTO_CASES = [
    EvalCase(
        question="How does the transformers library automatically select and configure the appropriate quantization method, and what happens when loading a pre-quantized model?",
        expected_files=[
            "quantizers/auto.py",
            "utils/quantization_config.py"
        ],
        expected_concepts=[
            "automatic quantizer selection",
            "quantization config mapping",
            "config merging behavior",
            "backwards compatibility for bitsandbytes",
            "quantization method resolution"
        ],
        expected_code_refs=[
            "AUTO_QUANTIZER_MAPPING",
            "AUTO_QUANTIZATION_CONFIG_MAPPING",
            "AutoHfQuantizer.from_config",
            "AutoQuantizationConfig.from_pretrained",
            "merge_quantization_configs method"
        ],
        difficulty=3
    )
]

# src/transformers/pipelines/base.py
TRANSFORMER_PIPELINE_BASE_TEST_CASES = [
    EvalCase(
        question="How does the Pipeline class handle model and device initialization?",
        expected_files=["base.py"],
        expected_concepts=[
            "device placement",
            "model initialization",
            "framework detection",
            "device type detection",
            "torch dtype handling"
        ],
        expected_code_refs=[
            "def __init__",
            "def device_placement",
            "infer_framework_load_model",
            "self.device = torch.device"
        ],
        difficulty=3
    ),
    EvalCase(
        question="How does the Pipeline class implement batched inference and data loading?",
        expected_files=["base.py", "pt_utils.py"],
        expected_concepts=[
            "batch processing",
            "data loading",
            "collate function",
            "padding implementation",
            "iterator pattern"
        ],
        expected_code_refs=[
            "def get_iterator",
            "class PipelineDataset",
            "class PipelineIterator",
            "_pad",
            "pad_collate_fn"
        ],
        difficulty=3
    )
]

# src/transformers/pipelines/text_generation.py
TRANSFORMER_PIPELINE_TEXT_GENERATION_TEST_CASES = [
    EvalCase(
        question="How does the TextGenerationPipeline handle chat-based generation and template processing?",
        expected_files=["text_generation.py", "base.py"],
        expected_concepts=[
            "chat message formatting",
            "template application",
            "message continuation",
            "role handling",
            "assistant prefill behavior"
        ],
        expected_code_refs=[
            "class Chat",
            "tokenizer.apply_chat_template",
            "continue_final_message",
            "isinstance(prompt_text, Chat)",
            "postprocess"
        ],
        difficulty=3
    )
]

# src/transformers/generation/logits_process.py
TRANSFORMER_LOGITS_PROCESSOR_TEST_CASES = [
    EvalCase(
        question="How does TopKLogitsWarper implement top-k filtering for text generation?",
        expected_files=["generation/logits_process.py"],
        expected_concepts=[
            "top-k filtering algorithm",
            "probability masking",
            "batch processing",
            "logits manipulation",
            "vocabulary filtering"
        ],
        expected_code_refs=[
            "class TopKLogitsWarper(LogitsProcessor)",
            "torch.topk(scores, top_k)[0]",
            "indices_to_remove = scores < torch.topk",
            "scores_processed = scores.masked_fill(indices_to_remove, self.filter_value)",
            "top_k = max(top_k, min_tokens_to_keep)"
        ],
        difficulty=3
    ),
    EvalCase(
        question="How does TemperatureLogitsProcessor implement temperature sampling for controlling generation randomness?",
        expected_files=["generation/logits_process.py"],
        expected_concepts=[
            "temperature scaling",
            "probability distribution shaping",
            "logits normalization",
            "generation randomness control",
            "batch processing with temperature"
        ],
        expected_code_refs=[
            "class TemperatureLogitsProcessor(LogitsProcessor)",
            "scores_processed = scores / self.temperature",
            "if not isinstance(temperature, float) or not (temperature > 0)",
            "def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor)",
            "raise ValueError(except_msg)"
        ],
        difficulty=3
    )
]

# src/transformers/trainer.py
TRANSFORMER_TRAINER_TEST_CASES = [
    EvalCase(
        question="How does Trainer handle distributed training and gradient accumulation? Explain the implementation details.",
        expected_files=["trainer.py"],
        expected_concepts=[
            "gradient accumulation steps",
            "distributed training logic",
            "optimizer step scheduling",
            "loss scaling",
            "device synchronization"
        ],
        expected_code_refs=[
            "def training_step",
            "def _wrap_model",
            "self.accelerator.backward",
            "self.args.gradient_accumulation_steps",
            "if args.n_gpu > 1",
            "model.zero_grad()"
        ],
        difficulty=3
    ),
    EvalCase(
        question="How does the Trainer class implement custom optimizer and learning rate scheduler creation? Explain the initialization process and supported configurations.",
        expected_files=["trainer.py"],
        expected_concepts=[
            "optimizer initialization",
            "learning rate scheduler",
            "weight decay handling",
            "optimizer parameter groups",
            "AdamW configuration",
            "custom optimizer support"
        ],
        expected_code_refs=[
            "def create_optimizer",
            "def create_scheduler",
            "get_decay_parameter_names",
            "optimizer_grouped_parameters",
            "self.args.learning_rate",
            "optimizer_kwargs"
        ],
        difficulty=3
    )
]

TRANSFORMER_TEST_CASES = [
    *TRANSFORMER_QUANTIZER_BASE_CASES,
    *TRANSFORMER_QUANTIZER_AUTO_CASES,
    *TRANSFORMER_PIPELINE_BASE_TEST_CASES,
    *TRANSFORMER_PIPELINE_TEXT_GENERATION_TEST_CASES,
    *TRANSFORMER_LOGITS_PROCESSOR_TEST_CASES,
    *TRANSFORMER_TRAINER_TEST_CASES,
]
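
# Minimal driver sketch (an assumption, not part of this module's API): it
# presumes a `process_chat(question) -> ChatResult` coroutine wired to the
# chatbot, which this file does not define, so it is left commented out.
#
#   async def main():
#       evaluator = ChatBotEvaluator(AppConfig())
#       results = await evaluator.evaluate_batch(
#           TRANSFORMER_TEST_CASES, process_chat, max_concurrent=2
#       )
#       for res in results:
#           print(f"{res.total_score:.2f}  {res.case.question[:60]}")
#
#   if __name__ == "__main__":
#       asyncio.run(main())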