from typing import Any, Dict, List, Optional, Union

from .artifact import fetch_artifact
from .logging_utils import get_logger
from .operator import StreamInstanceOperator
from .type_utils import (
    get_args,
    get_origin,
    isoftype,
    parse_type_string,
    verify_required_schema,
)


class Tasker:
    pass


class FormTask(Tasker, StreamInstanceOperator):
    """FormTask packs the different instance fields into dictionaries by their roles in the task.

    Attributes:
        inputs (Union[Dict[str, str], List[str]]):
            Dictionary mapping the names of instance input fields to the types of their values.
            In case a list is passed, each type is assumed to be Any.
        outputs (Union[Dict[str, str], List[str]]):
            Dictionary mapping the names of instance output fields to the types of their values.
            In case a list is passed, each type is assumed to be Any.
        metrics (List[str]): List of names of metrics to be used in the task.
        prediction_type (Optional[str]):
            Needs to be consistent with all used metrics. Defaults to None, which means it
            will be set to Any.

    The output instance contains three fields:
        "inputs" -- a sub-dictionary of the input instance, consisting of all the fields listed in 'inputs'.
        "outputs" -- the same, for the fields listed in 'outputs'.
        "metrics" -- the value of 'metrics'.
    """

    inputs: Union[Dict[str, str], List[str]]
    outputs: Union[Dict[str, str], List[str]]
    metrics: List[str]
    prediction_type: Optional[str] = None
    augmentable_inputs: List[str] = []

    def verify(self):
        for io_type in ["inputs", "outputs"]:
            data = self.inputs if io_type == "inputs" else self.outputs

            if not isoftype(data, Dict[str, str]):
                get_logger().warning(
                    f"'{io_type}' field of Task should be a dictionary of field names and their types. "
                    f"For example, {{'text': 'str', 'classes': 'List[str]'}}. Instead only '{data}' was "
                    f"passed. All types will be assumed to be 'Any'. In a future version of unitxt this "
                    f"will raise an exception."
                )
                # Fall back to treating every listed field as having type 'Any'.
                data = {key: "Any" for key in data}
                if io_type == "inputs":
                    self.inputs = data
                else:
                    self.outputs = data

        if not self.prediction_type:
            get_logger().warning(
                "'prediction_type' was not set in Task. It is used to check that the output of "
                "template post processors is compatible with the expected input of the metrics. "
                "Setting `prediction_type` to 'Any' (no checking is done). In a future version "
                "of unitxt this will raise an exception."
            )
            self.prediction_type = "Any"

        self.check_metrics_type()

        for augmentable_input in self.augmentable_inputs:
            assert (
                augmentable_input in self.inputs
            ), f"augmentable_input {augmentable_input} is not part of {self.inputs}"

    def check_metrics_type(self) -> None:
        prediction_type = parse_type_string(self.prediction_type)
        for metric_name in self.metrics:
            metric = fetch_artifact(metric_name)[0]
            metric_prediction_type = metric.get_prediction_type()

            # The task and metric are compatible if the types match exactly, if either
            # side is Any, or if the task's type is one member of a Union declared by
            # the metric.
            if (
                prediction_type == metric_prediction_type
                or prediction_type == Any
                or metric_prediction_type == Any
                or (
                    get_origin(metric_prediction_type) is Union
                    and prediction_type in get_args(metric_prediction_type)
                )
            ):
                continue

            raise ValueError(
                f"The task's prediction type ({prediction_type}) and '{metric_name}' "
                f"metric's prediction type ({metric_prediction_type}) are different."
            )

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        verify_required_schema(self.inputs, instance)
        verify_required_schema(self.outputs, instance)

        inputs = {key: instance[key] for key in self.inputs.keys()}
        outputs = {key: instance[key] for key in self.outputs.keys()}

        return {
            "inputs": inputs,
            "outputs": outputs,
            "metrics": self.metrics,
        }
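
# A rough illustration (assumed example, not part of this module) of what
# FormTask.process does: given inputs={"text": "Any"}, outputs={"label": "Any"}
# and metrics=["metrics.accuracy"], an instance such as
#     {"text": "hello", "label": "greeting", "extra": 1}
# would be repacked into
#     {"inputs": {"text": "hello"},
#      "outputs": {"label": "greeting"},
#      "metrics": ["metrics.accuracy"]}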


class MultipleChoiceTask(FormTask):
    """A FormTask that enumerates the choices field with letters from `alphabet`
    and replaces the target value with the label of its choice.
    """

    choices_field: str = "choices"
    choices_separator: str = "\n"
    enumeration_suffix: str = ". "
    use_text_in_target: bool = False
    alphabet: str = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"

    def process_single_choice(
        self, choice: str, index: int, use_text: bool = True
    ) -> str:
        try:
            processed_choice = f"{self.alphabet[index]}"
        except IndexError as e:
            raise ValueError(
                f"Too many choices, the length of alphabet '{self.alphabet}': {len(self.alphabet)} is the limit"
            ) from e

        if use_text:
            processed_choice += f"{self.enumeration_suffix}{choice}"
        return processed_choice

    def process_choices(self, choices: List[str]) -> str:
        processed_choices = []
        for index, choice in enumerate(choices):
            processed_choices.append(self.process_single_choice(choice, index))
        return self.choices_separator.join(processed_choices)
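
    # Worked example (for illustration only, using the default separator, suffix,
    # and alphabet): process_choices(["cat", "dog"]) would return "A. cat\nB. dog".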

    def process_target(self, choices, target_index):
        return self.process_single_choice(
            choices[target_index], target_index, use_text=self.use_text_in_target
        )

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        result = super().process(instance, stream_name)
        target_key, target_value = next(iter(result["outputs"].items()))
        choices = result["inputs"][self.choices_field]
        target_index_in_choices = choices.index(target_value)

        processed_choices = self.process_choices(choices)
        processed_target = self.process_target(choices, target_index_in_choices)

        result["inputs"][self.choices_field] = processed_choices
        result["outputs"][target_key] = processed_target
        return result
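
# End-to-end sketch (an assumption about typical usage, not code from this
# module): with inputs={"question": "str", "choices": "List[str]"} and
# outputs={"answer": "str"}, an instance like
#     {"question": "Which is a fish?", "choices": ["salmon", "sparrow"], "answer": "salmon"}
# would come out of MultipleChoiceTask.process with
#     inputs["choices"] == "A. salmon\nB. sparrow" and outputs["answer"] == "A"
# (since use_text_in_target defaults to False).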