from dataclasses import dataclass
from typing import Any, Final

from llmdataparser.base_parser import HuggingFaceDatasetParser, HuggingFaceParseEntry
from llmdataparser.prompts import MMLU_PRO_SYSTEM_PROMPT, MMLU_SYSTEM_PROMPT

MMLU_VALID_ANSWERS: Final[set[str]] = {"A", "B", "C", "D"}
MMLU_PRO_VALID_ANSWERS: Final[set[str]] = {
"A",
"B",
"C",
"D",
"E",
"F",
"G",
"H",
"I",
"J",
}
MMLU_VALID_ANSWER_STR: Final[str] = ", ".join(sorted(MMLU_VALID_ANSWERS))
MMLU_PRO_VALID_ANSWER_STR: Final[str] = ", ".join(sorted(MMLU_PRO_VALID_ANSWERS))
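
# Note: throughout this module, 0-based option indices are mapped to answer letters
# via chr(65 + index), e.g. 0 -> "A" and 3 -> "D".
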
@dataclass(frozen=True, kw_only=True, slots=True)
class MMLUParseEntry(HuggingFaceParseEntry):
"""Custom entry class for MMLU, with fields specific to this dataset parser."""
raw_choices: list[str]
task_name: str
@classmethod
def create(
cls,
prompt: str,
answer: str,
raw_question: str,
raw_choices: list[str],
raw_answer: str,
task_name: str,
) -> "MMLUParseEntry":
if answer not in MMLU_VALID_ANSWERS:
raise ValueError(
f"Invalid answer_letter '{answer}'; must be one of {MMLU_VALID_ANSWER_STR}"
)
if not task_name:
raise ValueError("Task name cannot be empty")
return cls(
prompt=prompt,
answer=answer,
raw_question=raw_question,
raw_answer=raw_answer,
raw_choices=raw_choices,
task_name=task_name,
)
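
# Illustrative only (values are made up): `create` validates the answer letter and the
# task name before building the frozen entry, so a call looks like
#   MMLUParseEntry.create(
#       prompt="...\nQuestion: What is 2 + 2?\nA. 3\nB. 4\nC. 5\nD. 6\nAnswer:",
#       answer="B",
#       raw_question="What is 2 + 2?",
#       raw_choices=["3", "4", "5", "6"],
#       raw_answer="1",
#       task_name="elementary_mathematics",
#   )
# whereas answer="E" would raise ValueError, since only A-D are valid for MMLU.
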
@dataclass(frozen=True, kw_only=True, slots=True)
class MMLUProParseEntry(HuggingFaceParseEntry):
"""Custom entry class for MMLU, with fields specific to this dataset parser."""
raw_choices: list[str]
task_name: str
@classmethod
def create(
cls,
prompt: str,
answer: str,
raw_question: str,
raw_choices: list[str],
raw_answer: str,
task_name: str,
) -> "MMLUProParseEntry":
if answer not in MMLU_PRO_VALID_ANSWERS:
raise ValueError(
f"Invalid answer_letter '{answer}'; must be one of {MMLU_PRO_VALID_ANSWER_STR}"
)
if not task_name:
raise ValueError("Task name cannot be empty")
return cls(
prompt=prompt,
answer=answer,
raw_question=raw_question,
raw_choices=raw_choices,
raw_answer=raw_answer,
task_name=task_name,
)

class MMLUDatasetParser(HuggingFaceDatasetParser[MMLUParseEntry]):
"""Base class for MMLU dataset parsers with common functionality."""
_default_system_prompt = MMLU_SYSTEM_PROMPT
def _get_task_from_entry(self, data_entry: dict[str, Any]) -> str:
"""Get the task name from the data entry or default task name."""
task_name = data_entry.get("subject")
return task_name if task_name else (self._current_task or self._default_task)
def process_entry(
self, row: dict[str, Any], task_name: str | None = None, **kwargs: Any
) -> MMLUParseEntry:
"""
Generate a prompt and expected answer from the given row.
Args:
row: A data point to be formatted.
task_name: Optional task name for the entry.
**kwargs: Additional keyword arguments.
Returns:
MMLUParseEntry: The formatted entry object.
"""
task = task_name or self._get_current_task(row)
# Ensure task is not None
final_task = task or self._default_task
choices = "\n".join(
f"{chr(65 + i)}. {choice}" for i, choice in enumerate(row["choices"])
)
raw_question = row["question"]
raw_choices = row["choices"]
raw_answer = str(row["answer"]) # Ensure raw_answer is a string
prompt = f"{self._system_prompt}\nQuestion: {raw_question}\n{choices}\nAnswer:"
answer_letter = chr(65 + int(raw_answer)) # Convert index to 'A', 'B', 'C', 'D'
return MMLUParseEntry.create(
prompt=prompt,
answer=answer_letter,
raw_question=raw_question,
raw_choices=raw_choices,
raw_answer=raw_answer,
task_name=final_task,
)
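
# Illustrative only (hypothetical row): given an MMLU-style row such as
#   {"question": "What is 2 + 2?", "choices": ["3", "4", "5", "6"], "answer": 1,
#    "subject": "elementary_mathematics"}
# process_entry renders the choices as "A. 3\nB. 4\nC. 5\nD. 6", appends them to the
# system prompt and question, and maps answer index 1 to the letter "B".
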
class BaseMMLUDatasetParser(MMLUDatasetParser):
"""Parser for the original MMLU dataset."""
_data_source = "cais/mmlu"
_default_task = "all"
_task_names = [
"abstract_algebra",
"anatomy",
"astronomy",
"business_ethics",
"clinical_knowledge",
"college_biology",
"college_chemistry",
"college_computer_science",
"college_mathematics",
"college_medicine",
"college_physics",
"computer_security",
"conceptual_physics",
"econometrics",
"electrical_engineering",
"elementary_mathematics",
"formal_logic",
"global_facts",
"high_school_biology",
"high_school_chemistry",
"high_school_computer_science",
"high_school_european_history",
"high_school_geography",
"high_school_government_and_politics",
"high_school_macroeconomics",
"high_school_mathematics",
"high_school_microeconomics",
"high_school_physics",
"high_school_psychology",
"high_school_statistics",
"high_school_us_history",
"high_school_world_history",
"human_aging",
"human_sexuality",
"international_law",
"jurisprudence",
"logical_fallacies",
"machine_learning",
"management",
"marketing",
"medical_genetics",
"miscellaneous",
"moral_disputes",
"moral_scenarios",
"nutrition",
"philosophy",
"prehistory",
"professional_accounting",
"professional_law",
"professional_medicine",
"professional_psychology",
"public_relations",
"security_studies",
"sociology",
"us_foreign_policy",
"virology",
"world_religions",
]

class MMLUReduxDatasetParser(MMLUDatasetParser):
"""Parser for the MMLU Redux dataset."""
_data_source = "edinburgh-dawg/mmlu-redux"
_default_task = "anatomy"
_task_names = [
"anatomy",
"astronomy",
"business_ethics",
"clinical_knowledge",
"college_chemistry",
"college_computer_science",
"college_mathematics",
"college_medicine",
"college_physics",
"conceptual_physics",
"econometrics",
"electrical_engineering",
"formal_logic",
"global_facts",
"high_school_chemistry",
"high_school_geography",
"high_school_macroeconomics",
"high_school_mathematics",
"high_school_physics",
"high_school_statistics",
"high_school_us_history",
"human_aging",
"logical_fallacies",
"machine_learning",
"miscellaneous",
"philosophy",
"professional_accounting",
"professional_law",
"public_relations",
"virology",
]

class TMMLUPlusDatasetParser(MMLUDatasetParser):
"""Parser for the TMMLU+ dataset."""
_data_source = "ikala/tmmluplus"
_default_task = "taiwanese_hokkien"
_task_names = [
"engineering_math",
"dentistry",
"traditional_chinese_medicine_clinical_medicine",
"clinical_psychology",
"technical",
"culinary_skills",
"mechanical",
"logic_reasoning",
"real_estate",
"general_principles_of_law",
"finance_banking",
"anti_money_laundering",
"ttqav2",
"marketing_management",
"business_management",
"organic_chemistry",
"advance_chemistry",
"physics",
"secondary_physics",
"human_behavior",
"national_protection",
"jce_humanities",
"politic_science",
"agriculture",
"official_document_management",
"financial_analysis",
"pharmacy",
"educational_psychology",
"statistics_and_machine_learning",
"management_accounting",
"introduction_to_law",
"computer_science",
"veterinary_pathology",
"accounting",
"fire_science",
"optometry",
"insurance_studies",
"pharmacology",
"taxation",
"trust_practice",
"geography_of_taiwan",
"physical_education",
"auditing",
"administrative_law",
"education_(profession_level)",
"economics",
"veterinary_pharmacology",
"nautical_science",
"occupational_therapy_for_psychological_disorders",
"basic_medical_science",
"macroeconomics",
"trade",
"chinese_language_and_literature",
"tve_design",
"junior_science_exam",
"junior_math_exam",
"junior_chinese_exam",
"junior_social_studies",
"tve_mathematics",
"tve_chinese_language",
"tve_natural_sciences",
"junior_chemistry",
"music",
"education",
"three_principles_of_people",
"taiwanese_hokkien",
]
def process_entry(
self, row: dict[str, Any], task_name: str | None = None, **kwargs: Any
) -> MMLUParseEntry:
"""Process a single TMMLU+ entry."""
# Extract choices in order
raw_choices = [row["A"], row["B"], row["C"], row["D"]]
choices = "\n".join(
f"{chr(65 + i)}. {choice}" for i, choice in enumerate(raw_choices)
)
raw_question = row["question"]
raw_answer = row["answer"]
prompt = f"{self._system_prompt}\nQuestion: {raw_question}\n{choices}\nAnswer:"
        task = task_name or self._get_current_task(row)
        final_task = task or self._default_task

        return MMLUParseEntry.create(
            prompt=prompt,
            answer=raw_answer,
            raw_question=raw_question,
            raw_choices=raw_choices,
            raw_answer=raw_answer,
            task_name=final_task,
        )
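
# Illustrative only (hypothetical row): TMMLU+ stores the four options in the "A"-"D"
# columns and the answer as a letter, e.g.
#   {"question": "...", "A": "...", "B": "...", "C": "...", "D": "...", "answer": "C"}
# so the answer is passed through unchanged instead of being converted from an index.
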
class MMLUProDatasetParser(HuggingFaceDatasetParser[MMLUProParseEntry]):
"""Parser for the MMLU Pro dataset."""
_data_source = "TIGER-Lab/MMLU-Pro"
_default_task = "default"
_task_names = [
"math",
"physics",
"chemistry",
"law",
"engineering",
"other",
"economics",
"health",
"psychology",
"business",
"biology",
"philosophy",
"computer_science",
"history",
]
_default_system_prompt = MMLU_PRO_SYSTEM_PROMPT
def _get_task_from_entry(self, data_entry: dict[str, Any]) -> str:
"""Get the task name from the data entry or default task name."""
if data_entry is not None:
task_name = data_entry.get("category")
if task_name:
return task_name
return self._current_task or self._default_task
def process_entry(
self, row: dict[str, Any], task_name: str | None = None, **kwargs: Any
) -> MMLUProParseEntry:
"""
Generate a prompt and expected answer from the given row.
Args:
row (dict[str, Any]): A data point to be formatted with MMLU Pro specific structure
containing 'question', 'options', 'answer', and 'answer_index' keys.
Returns:
MMLUParseEntry: The formatted entry object.
"""
task = task_name or self._get_current_task(row)
# Ensure task is not None
final_task = task or self._default_task
# Extract choices in order
raw_choices = row["options"]
choices = "\n".join(
f"{chr(65 + i)}. {choice}" for i, choice in enumerate(raw_choices)
)
raw_question = row["question"]
raw_answer = row["answer"]
answer_index = row["answer_index"]
prompt = f"{self._system_prompt}\nQuestion: {raw_question}\n{choices}\nAnswer:"
        # Convert the 0-based answer index to a letter ('A' through 'J').
        answer_letter = chr(65 + answer_index)

        return MMLUProParseEntry.create(
            prompt=prompt,
            answer=answer_letter,
            raw_question=raw_question,
            raw_choices=raw_choices,
            raw_answer=raw_answer,
            task_name=final_task,
        )
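
# Illustrative only: MMLU-Pro rows can carry up to ten options, so an answer_index of 9
# becomes the letter "J"; raw_answer keeps the dataset's original answer field for
# reference alongside the derived letter.
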
if __name__ == "__main__":
# Example usage of MMLU Pro parser
parser = MMLUProDatasetParser()
parser.load()
parser.parse()
# Get parsed data with correct type
parsed_data = parser.get_parsed_data
# Print example entry
if parsed_data:
example = parsed_data[0]
print("\nExample parsed entry:")
print(f"Task: {example.task_name}")
print(f"Question: {example.raw_question}")
print("Choices:")
for i, choice in enumerate(example.raw_choices):
print(f"{chr(65 + i)}. {choice}")
print(f"Correct Answer: {example.answer}")