File size: 2,752 Bytes
3da1d9d 9d5b4c0 066c396 47125b5 066c396 9d5b4c0 0a1b314 066c396 3da1d9d 0a1b314 066c396 9d5b4c0 066c396 b462f85 7f6dcb7 b462f85 f6ebc4f b462f85 f6ebc4f b462f85 f6ebc4f 9d5b4c0 b462f85 3da1d9d 7f6dcb7 066c396 f4655a2 066c396 f4655a2 9d5b4c0 066c396 7f6dcb7 066c396 7f6dcb7 066c396 7f6dcb7 066c396 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 |
import json
from typing import Any, Dict, Optional
from datasets import Features, Sequence, Value
from .artifact import Artifact
from .operator import InstanceOperatorValidator
# Canonical HuggingFace `datasets` schema for a finalized unitxt instance.
# Every instance emitted by `Finalize.process` must contain exactly these
# keys; `Finalize.validate` checks membership and encodability against it.
UNITXT_DATASET_SCHEMA = Features(
    {
        "source": Value("string"),
        "target": Value("string"),
        "references": Sequence(Value("string")),
        "metrics": Sequence(Value("string")),
        "group": Value("string"),
        "postprocessors": Sequence(Value("string")),
        # JSON-serialized dict (see Finalize.process) — stored as a plain
        # string so arbitrary task fields survive the fixed schema.
        "task_data": Value(dtype="string"),
        "data_classification_policy": Sequence(Value("string")),
    }
)
class Finalize(InstanceOperatorValidator):
    """Shape each instance into the final UNITXT_DATASET_SCHEMA form.

    Packs the task's input/reference fields (plus recipe metadata) into a
    JSON string under ``task_data``, optionally strips keys not in the
    schema, fills in a default ``group``, and serializes Artifact-valued
    metrics/postprocessors to JSON strings.
    """

    # When True, every key not declared in UNITXT_DATASET_SCHEMA is dropped.
    remove_unnecessary_fields: bool = True

    @staticmethod
    def artifact_to_jsonable(artifact):
        """Return a JSON-friendly form: the artifact's id when registered,
        otherwise its full dict representation."""
        return artifact.to_dict() if artifact.__id__ is None else artifact.__id__

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        """Finalize ``instance`` in place and return it.

        Assumes the instance carries ``input_fields``, ``reference_fields``,
        ``data_classification_policy``, ``recipe_metadata``, ``metrics`` and
        ``postprocessors`` keys, as produced by earlier pipeline operators.
        """
        metadata = {
            "data_classification_policy": instance["data_classification_policy"],
            "template": self.artifact_to_jsonable(
                instance["recipe_metadata"]["template"]
            ),
            "num_demos": instance["recipe_metadata"]["num_demos"],
        }
        task_data = {
            **instance["input_fields"],
            **instance["reference_fields"],
            "metadata": metadata,
        }
        # The schema stores task_data as a plain string, so serialize here.
        instance["task_data"] = json.dumps(task_data)

        if self.remove_unnecessary_fields:
            # Collect first, delete after — avoids mutating while iterating.
            extraneous = [k for k in instance if k not in UNITXT_DATASET_SCHEMA]
            for k in extraneous:
                del instance[k]

        instance.setdefault("group", "unitxt")
        instance["metrics"] = [
            m.to_json() if isinstance(m, Artifact) else m
            for m in instance["metrics"]
        ]
        instance["postprocessors"] = [
            p.to_json() if isinstance(p, Artifact) else p
            for p in instance["postprocessors"]
        ]
        return instance

    def validate(self, instance: Dict[str, Any], stream_name: Optional[str] = None):
        # verify the instance has the required schema
        assert instance is not None, "Instance is None"
        assert isinstance(
            instance, dict
        ), f"Instance should be a dict, got {type(instance)}"
        assert all(
            key in instance for key in UNITXT_DATASET_SCHEMA
        ), f"Instance should have the following keys: {UNITXT_DATASET_SCHEMA}. Instance is: {instance}"
        UNITXT_DATASET_SCHEMA.encode_example(instance)
|