Spaces:
Running
Running
File size: 4,337 Bytes
4c7982b 044ed98 4c7982b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 |
from dataclasses import dataclass, field
from datasets import load_dataset, Dataset
from functools import cached_property
from tqdm.auto import tqdm
from typing import Any, Optional, Protocol, Iterable, Callable
from .utils import (
NUMERIC_IN_ZH,
extract_choice_ans,
extract_numeric,
get_answer,
is_equiv,
)
from evaluate import load
# A text-generation pipeline maps a batch of prompts to a batch of generated texts.
TextGenerationPipeline = Callable[[Iterable[str]], list[str]]


def fake_pipeline(prompts: Iterable[str]) -> list[str]:
    """Identity pipeline for dry runs: echoes every prompt back, with a progress bar."""
    return list(tqdm(prompts))
@dataclass
class Task:
    """One evaluation task: a dataset split, an optional prompt transform, and a metric.

    The dataset, its input samples, and the metric are all loaded lazily and
    cached on first access via ``cached_property``.
    """

    # Plain dataset name, or a (name, config) pair forwarded to `load_dataset`.
    dataset_name: str | tuple[str, str] = ("gsm8k", "main")
    split: str = "test"
    # Plain metric path, or a (path, config) pair forwarded to `evaluate.load`.
    metric_name: str | tuple[str, str] = ("sustech/tlem", "gsm8k")
    input_column: str = "question"
    label_column: str = "answer"
    # Either a format string applied to the input column (via the
    # `input_column` placeholder) or a callable mapping an example dict to
    # an updated example dict.
    prompt: Optional[Callable | str] = None

    @cached_property
    def name(self):
        """Readable identifier, e.g. ``"gsm8k-test"`` (config part is dropped)."""
        base = (
            self.dataset_name
            if isinstance(self.dataset_name, str)
            else self.dataset_name[0]
        )
        return f"{base}-{self.split}"

    @cached_property
    def samples(self):
        """Model inputs: the (possibly prompt-transformed) input column."""
        return self.dataset[self.input_column]

    @cached_property
    def dataset(self):
        """Load the configured split and apply the prompt transform, if any."""
        load_args = (
            self.dataset_name
            if isinstance(self.dataset_name, tuple)
            else (self.dataset_name,)
        )
        ds = load_dataset(*load_args, split=self.split)
        if self.prompt is None:
            return ds
        if isinstance(self.prompt, str):
            template = self.prompt
            column = self.input_column
            return ds.map(
                lambda example: {
                    column: template.format(input_column=example[column])
                }
            )
        # Callable prompt: it receives the whole example dict.
        return ds.map(self.prompt)

    @cached_property
    def metric(self):
        """Load the scoring metric through ``evaluate.load``."""
        if isinstance(self.metric_name, str):
            return load(self.metric_name)
        return load(*self.metric_name)

    def run(self, pipeline: TextGenerationPipeline = fake_pipeline):
        """Generate a response per sample and score against the label column."""
        responses = pipeline(self.samples)
        return self.metric.compute(
            responses=responses, references=self.dataset[self.label_column]
        )
class Metrics:
    """Per-benchmark scoring functions.

    Every scorer takes the model responses and the gold answers (zipped
    pairwise) and returns a list of per-example scores: 1.0 on a match,
    0.0 otherwise. All scorers are stateless, so they are declared
    ``@staticmethod`` — callable on the class or on an instance alike
    (the original undecorated form raised ``TypeError`` when invoked on
    a ``Metrics()`` instance).
    """

    @staticmethod
    def gsm8k(responses: list[str], answers: list[str | int]) -> list[float]:
        """Exact match on the final numeric answer extracted from the response."""
        scores = []
        for response, answer in zip(responses, answers):
            pred = extract_numeric(response)
            # Gold answers may arrive as ints; normalize to str for comparison.
            gold = extract_numeric(answer) if isinstance(answer, str) else str(answer)
            scores.append(1.0 * (pred == gold))
        return scores

    @staticmethod
    def MATH(responses: list[str], answers: list[str]) -> list[float]:
        """Compare the last ``$...$``-delimited expression against the gold answer.

        Responses with two or fewer ``$`` characters score 0.0 (mirrors the
        original ``<= 2`` guard; with exactly two ``$`` a single well-formed
        pair exists but is still rejected — TODO confirm that is intended).
        """
        scores = []
        for response, answer in zip(responses, answers):
            indices = [pos for pos, char in enumerate(response) if char == "$"]
            if len(indices) <= 2:
                # 0.0 (not int 0) so the score list is uniformly float.
                scores.append(0.0)
                continue
            result = response[indices[-2] + 1 : indices[-1]]
            gold = get_answer(answer)
            scores.append(1.0 * is_equiv(result, gold))
        return scores

    @staticmethod
    def math23k(responses: list[str], answers: list[str]) -> list[float]:
        """Exact numeric match for the Chinese Math23K benchmark."""
        scores = []
        for response, answer in zip(responses, answers):
            pred = extract_numeric(response, pattern=NUMERIC_IN_ZH)
            gold = extract_numeric(answer, pattern=NUMERIC_IN_ZH)
            scores.append(1.0 * (pred == gold))
        return scores

    @staticmethod
    def gsm8k_zh(responses: list[str], answers: list[str]) -> list[float]:
        """GSM8K-zh: Chinese responses, but gold answers use the default pattern."""
        scores = []
        for response, answer in zip(responses, answers):
            pred = extract_numeric(response, pattern=NUMERIC_IN_ZH)
            gold = extract_numeric(answer)
            scores.append(1.0 * (pred == gold))
        return scores

    @staticmethod
    def svamp(responses: list[float], answers: list[str]) -> list[float]:
        """SVAMP: numeric comparison of extracted prediction vs. gold answer.

        NOTE(review): the annotations look inconsistent with the body —
        ``extract_numeric`` is applied to each response (suggesting str
        input) with the ZH pattern on an English benchmark, and
        ``float(pred) == gold`` compares a float to a gold value annotated
        ``str``, which is never equal for true str golds. Behavior kept
        as-is; confirm against the callers before changing.
        """
        scores = []
        for response, answer in zip(responses, answers):
            pred = extract_numeric(response, pattern=NUMERIC_IN_ZH)
            gold = answer
            scores.append(1.0 * (float(pred) == gold))
        return scores

    @staticmethod
    def mmlu(responses, answers) -> list[float]:
        """MMLU: case-insensitive match of the extracted multiple-choice letter."""
        scores = []
        for response, answer in zip(responses, answers):
            pred = extract_choice_ans(response)
            gold = answer.lower()
            scores.append(1.0 * (pred == gold))
        return scores
|