from dataclasses import dataclass
from typing import Any, Final

from llmdataparser.base_parser import HuggingFaceDatasetParser, HuggingFaceParseEntry
from llmdataparser.prompts import MMLU_PRO_SYSTEM_PROMPT, MMLU_SYSTEM_PROMPT

MMLU_VALID_ANSWERS: Final[set[str]] = {"A", "B", "C", "D"}
MMLU_PRO_VALID_ANSWERS: Final[set[str]] = {
    "A",
    "B",
    "C",
    "D",
    "E",
    "F",
    "G",
    "H",
    "I",
    "J",
}
MMLU_VALID_ANSWER_STR: Final[str] = ", ".join(sorted(MMLU_VALID_ANSWERS))
MMLU_PRO_VALID_ANSWER_STR: Final[str] = ", ".join(sorted(MMLU_PRO_VALID_ANSWERS))


@dataclass(frozen=True, kw_only=True, slots=True)
class MMLUParseEntry(HuggingFaceParseEntry):
    """Custom entry class for MMLU, with fields specific to this dataset parser."""

    raw_choices: list[str]
    task_name: str

    @classmethod
    def create(
        cls,
        prompt: str,
        answer: str,
        raw_question: str,
        raw_choices: list[str],
        raw_answer: str,
        task_name: str,
    ) -> "MMLUParseEntry":
        if answer not in MMLU_VALID_ANSWERS:
            raise ValueError(
                f"Invalid answer_letter '{answer}'; must be one of {MMLU_VALID_ANSWER_STR}"
            )
        if not task_name:
            raise ValueError("Task name cannot be empty")
        return cls(
            prompt=prompt,
            answer=answer,
            raw_question=raw_question,
            raw_answer=raw_answer,
            raw_choices=raw_choices,
            task_name=task_name,
        )
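
# Illustrative usage note: MMLUParseEntry.create() validates the answer letter
# before building the entry, so e.g. answer="E" raises ValueError for plain
# MMLU, which only accepts A-D.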


@dataclass(frozen=True, kw_only=True, slots=True)
class MMLUProParseEntry(HuggingFaceParseEntry):
    """Custom entry class for MMLU-Pro, with fields specific to this dataset parser."""

    raw_choices: list[str]
    task_name: str

    @classmethod
    def create(
        cls,
        prompt: str,
        answer: str,
        raw_question: str,
        raw_choices: list[str],
        raw_answer: str,
        task_name: str,
    ) -> "MMLUProParseEntry":
        if answer not in MMLU_PRO_VALID_ANSWERS:
            raise ValueError(
                f"Invalid answer_letter '{answer}'; must be one of {MMLU_PRO_VALID_ANSWER_STR}"
            )
        if not task_name:
            raise ValueError("Task name cannot be empty")
        return cls(
            prompt=prompt,
            answer=answer,
            raw_question=raw_question,
            raw_choices=raw_choices,
            raw_answer=raw_answer,
            task_name=task_name,
        )


class MMLUDatasetParser(HuggingFaceDatasetParser[MMLUParseEntry]):
    """Base class for MMLU dataset parsers with common functionality."""

    _default_system_prompt = MMLU_SYSTEM_PROMPT

    def _get_task_from_entry(self, data_entry: dict[str, Any]) -> str:
        """Get the task name from the data entry, falling back to the current or default task."""
        task_name = data_entry.get("subject")
        return task_name if task_name else (self._current_task or self._default_task)

    def process_entry(
        self, row: dict[str, Any], task_name: str | None = None, **kwargs: Any
    ) -> MMLUParseEntry:
        """
        Generate a prompt and expected answer from the given row.

        Args:
            row: A data point to be formatted.
            task_name: Optional task name for the entry.
            **kwargs: Additional keyword arguments.

        Returns:
            MMLUParseEntry: The formatted entry object.
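
        Example (illustrative, not a real dataset row):
            A row like {"question": "...", "choices": ["w", "x", "y", "z"],
            "answer": 2} yields a prompt listing the choices as "A." through
            "D." and the answer letter "C" (chr(65 + 2)).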
        """
        task = task_name or self._get_current_task(row)
        final_task = task or self._default_task

        choices = "\n".join(
            f"{chr(65 + i)}. {choice}" for i, choice in enumerate(row["choices"])
        )
        raw_question = row["question"]
        raw_choices = row["choices"]
        raw_answer = str(row["answer"])

        prompt = f"{self._system_prompt}\nQuestion: {raw_question}\n{choices}\nAnswer:"
        answer_letter = chr(65 + int(raw_answer))

        return MMLUParseEntry.create(
            prompt=prompt,
            answer=answer_letter,
            raw_question=raw_question,
            raw_choices=raw_choices,
            raw_answer=raw_answer,
            task_name=final_task,
        )


class BaseMMLUDatasetParser(MMLUDatasetParser):
    """Parser for the original MMLU dataset."""

    _data_source = "cais/mmlu"
    _default_task = "all"
    _task_names = [
        "abstract_algebra",
        "anatomy",
        "astronomy",
        "business_ethics",
        "clinical_knowledge",
        "college_biology",
        "college_chemistry",
        "college_computer_science",
        "college_mathematics",
        "college_medicine",
        "college_physics",
        "computer_security",
        "conceptual_physics",
        "econometrics",
        "electrical_engineering",
        "elementary_mathematics",
        "formal_logic",
        "global_facts",
        "high_school_biology",
        "high_school_chemistry",
        "high_school_computer_science",
        "high_school_european_history",
        "high_school_geography",
        "high_school_government_and_politics",
        "high_school_macroeconomics",
        "high_school_mathematics",
        "high_school_microeconomics",
        "high_school_physics",
        "high_school_psychology",
        "high_school_statistics",
        "high_school_us_history",
        "high_school_world_history",
        "human_aging",
        "human_sexuality",
        "international_law",
        "jurisprudence",
        "logical_fallacies",
        "machine_learning",
        "management",
        "marketing",
        "medical_genetics",
        "miscellaneous",
        "moral_disputes",
        "moral_scenarios",
        "nutrition",
        "philosophy",
        "prehistory",
        "professional_accounting",
        "professional_law",
        "professional_medicine",
        "professional_psychology",
        "public_relations",
        "security_studies",
        "sociology",
        "us_foreign_policy",
        "virology",
        "world_religions",
    ]


class MMLUReduxDatasetParser(MMLUDatasetParser):
    """Parser for the MMLU Redux dataset."""

    _data_source = "edinburgh-dawg/mmlu-redux"
    _default_task = "anatomy"
    _task_names = [
        "anatomy",
        "astronomy",
        "business_ethics",
        "clinical_knowledge",
        "college_chemistry",
        "college_computer_science",
        "college_mathematics",
        "college_medicine",
        "college_physics",
        "conceptual_physics",
        "econometrics",
        "electrical_engineering",
        "formal_logic",
        "global_facts",
        "high_school_chemistry",
        "high_school_geography",
        "high_school_macroeconomics",
        "high_school_mathematics",
        "high_school_physics",
        "high_school_statistics",
        "high_school_us_history",
        "human_aging",
        "logical_fallacies",
        "machine_learning",
        "miscellaneous",
        "philosophy",
        "professional_accounting",
        "professional_law",
        "public_relations",
        "virology",
    ]


class TMMLUPlusDatasetParser(MMLUDatasetParser):
    """Parser for the TMMLU+ dataset."""

    _data_source = "ikala/tmmluplus"
    _default_task = "taiwanese_hokkien"
    _task_names = [
        "engineering_math",
        "dentistry",
        "traditional_chinese_medicine_clinical_medicine",
        "clinical_psychology",
        "technical",
        "culinary_skills",
        "mechanical",
        "logic_reasoning",
        "real_estate",
        "general_principles_of_law",
        "finance_banking",
        "anti_money_laundering",
        "ttqav2",
        "marketing_management",
        "business_management",
        "organic_chemistry",
        "advance_chemistry",
        "physics",
        "secondary_physics",
        "human_behavior",
        "national_protection",
        "jce_humanities",
        "politic_science",
        "agriculture",
        "official_document_management",
        "financial_analysis",
        "pharmacy",
        "educational_psychology",
        "statistics_and_machine_learning",
        "management_accounting",
        "introduction_to_law",
        "computer_science",
        "veterinary_pathology",
        "accounting",
        "fire_science",
        "optometry",
        "insurance_studies",
        "pharmacology",
        "taxation",
        "trust_practice",
        "geography_of_taiwan",
        "physical_education",
        "auditing",
        "administrative_law",
        "education_(profession_level)",
        "economics",
        "veterinary_pharmacology",
        "nautical_science",
        "occupational_therapy_for_psychological_disorders",
        "basic_medical_science",
        "macroeconomics",
        "trade",
        "chinese_language_and_literature",
        "tve_design",
        "junior_science_exam",
        "junior_math_exam",
        "junior_chinese_exam",
        "junior_social_studies",
        "tve_mathematics",
        "tve_chinese_language",
        "tve_natural_sciences",
        "junior_chemistry",
        "music",
        "education",
        "three_principles_of_people",
        "taiwanese_hokkien",
    ]

    def process_entry(
        self, row: dict[str, Any], task_name: str | None = None, **kwargs: Any
    ) -> MMLUParseEntry:
        """Process a single TMMLU+ entry."""
        raw_choices = [row["A"], row["B"], row["C"], row["D"]]
        choices = "\n".join(
            f"{chr(65 + i)}. {choice}" for i, choice in enumerate(raw_choices)
        )
        raw_question = row["question"]
        raw_answer = row["answer"]

        prompt = f"{self._system_prompt}\nQuestion: {raw_question}\n{choices}\nAnswer:"
        task = task_name or self._get_current_task(row)

        return MMLUParseEntry.create(
            prompt=prompt,
            answer=raw_answer,
            raw_question=raw_question,
            raw_choices=raw_choices,
            raw_answer=raw_answer,
            task_name=task,
        )


class MMLUProDatasetParser(HuggingFaceDatasetParser[MMLUProParseEntry]):
    """Parser for the MMLU-Pro dataset."""

    _data_source = "TIGER-Lab/MMLU-Pro"
    _default_task = "default"
    _task_names = [
        "math",
        "physics",
        "chemistry",
        "law",
        "engineering",
        "other",
        "economics",
        "health",
        "psychology",
        "business",
        "biology",
        "philosophy",
        "computer_science",
        "history",
    ]
    _default_system_prompt = MMLU_PRO_SYSTEM_PROMPT

    def _get_task_from_entry(self, data_entry: dict[str, Any]) -> str:
        """Get the task name from the data entry, falling back to the current or default task."""
        if data_entry is not None:
            task_name = data_entry.get("category")
            if task_name:
                return task_name
        return self._current_task or self._default_task

    def process_entry(
        self, row: dict[str, Any], task_name: str | None = None, **kwargs: Any
    ) -> MMLUProParseEntry:
        """
        Generate a prompt and expected answer from the given row.

        Args:
            row (dict[str, Any]): A data point with the MMLU-Pro structure,
                containing 'question', 'options', 'answer', and 'answer_index' keys.
            task_name: Optional task name for the entry.
            **kwargs: Additional keyword arguments.

        Returns:
            MMLUProParseEntry: The formatted entry object.
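
        Example (illustrative, not a real dataset row):
            A row with ten options and answer_index 4 maps to the answer
            letter "E" (chr(65 + 4)); valid MMLU-Pro letters run A through J.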
        """
        task = task_name or self._get_current_task(row)
        final_task = task or self._default_task

        raw_choices = row["options"]
        choices = "\n".join(
            f"{chr(65 + i)}. {choice}" for i, choice in enumerate(raw_choices)
        )
        raw_question = row["question"]
        raw_answer = row["answer"]
        answer_index = row["answer_index"]

        prompt = f"{self._system_prompt}\nQuestion: {raw_question}\n{choices}\nAnswer:"
        answer_letter = chr(65 + answer_index)

        return MMLUProParseEntry.create(
            prompt=prompt,
            answer=answer_letter,
            raw_question=raw_question,
            raw_choices=raw_choices,
            raw_answer=raw_answer,
            task_name=final_task,
        )


if __name__ == "__main__":
    parser = MMLUProDatasetParser()
    parser.load()
    parser.parse()

    parsed_data = parser.get_parsed_data
    if parsed_data:
        example = parsed_data[0]
        print("\nExample parsed entry:")
        print(f"Task: {example.task_name}")
        print(f"Question: {example.raw_question}")
        print("Choices:")
        for i, choice in enumerate(example.raw_choices):
            print(f"{chr(65 + i)}. {choice}")
        print(f"Correct Answer: {example.answer}")