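"""Dataset schema definition and instance finalization for unitxt.

Declares the flat HuggingFace ``datasets`` schema that every processed
instance is serialized into, and the ``Finalize`` operator that packs the
task-specific fields into a JSON ``task_data`` string, prunes everything
else, and validates the result against the schema.
"""
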
import json
from typing import Any, Dict, Optional

from datasets import Features, Sequence, Value

from .artifact import Artifact
from .operator import InstanceOperatorValidator

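# Flat schema that every finalized instance must match. All task-specific
# fields are folded into the JSON-encoded ``task_data`` string so the
# feature set stays fixed across tasks.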
UNITXT_DATASET_SCHEMA = Features(
    {
        "source": Value("string"),
        "target": Value("string"),
        "references": Sequence(Value("string")),
        "metrics": Sequence(Value("string")),
        "group": Value("string"),
        "postprocessors": Sequence(Value("string")),
        "task_data": Value(dtype="string"),
        "data_classification_policy": Sequence(Value("string")),
    }
)


class Finalize(InstanceOperatorValidator):
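    """Final operator that shapes instances to ``UNITXT_DATASET_SCHEMA``.

    Folds the instance's input and reference fields, together with its
    recipe metadata, into a JSON-encoded ``task_data`` string, optionally
    drops any field outside the schema, and JSON-encodes ``Artifact``
    metrics and postprocessors so every value fits the flat schema.
    """
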
    remove_unnecessary_fields: bool = True

    @staticmethod
    def artifact_to_jsonable(artifact):
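        """Return the artifact's id if it has one, else its full dict form."""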
        if artifact.__id__ is None:
            return artifact.to_dict()
        return artifact.__id__

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
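        """Finalize a single instance.

        Builds the ``task_data`` JSON payload from the instance's input and
        reference fields plus recipe metadata, optionally prunes fields not
        declared in ``UNITXT_DATASET_SCHEMA``, defaults ``group`` to
        ``"unitxt"``, and serializes ``Artifact`` metrics and postprocessors
        to JSON strings.
        """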
        task_data = {
            **instance["input_fields"],
            **instance["reference_fields"],
            "metadata": {
                "data_classification_policy": instance["data_classification_policy"],
                "template": self.artifact_to_jsonable(
                    instance["recipe_metadata"]["template"]
                ),
                "num_demos": instance["recipe_metadata"]["num_demos"],
            },
        }
        instance["task_data"] = json.dumps(task_data)

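        # Drop every field that is not declared in UNITXT_DATASET_SCHEMA.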
        if self.remove_unnecessary_fields:
            keys_to_delete = []

            for key in instance.keys():
                if key not in UNITXT_DATASET_SCHEMA:
                    keys_to_delete.append(key)

            for key in keys_to_delete:
                del instance[key]
        if "group" not in instance:
            instance["group"] = "unitxt"
        instance["metrics"] = [
            metric.to_json() if isinstance(metric, Artifact) else metric
            for metric in instance["metrics"]
        ]
        instance["postprocessors"] = [
            processor.to_json() if isinstance(processor, Artifact) else processor
            for processor in instance["postprocessors"]
        ]
        return instance

    def validate(self, instance: Dict[str, Any], stream_name: Optional[str] = None):
        """Assert that the finalized instance matches UNITXT_DATASET_SCHEMA."""
        assert instance is not None, "Instance is None"
        assert isinstance(
            instance, dict
        ), f"Instance should be a dict, got {type(instance)}"
        assert all(
            key in instance for key in UNITXT_DATASET_SCHEMA
        ), f"Instance should have the following keys: {UNITXT_DATASET_SCHEMA}. Instance is: {instance}"
        UNITXT_DATASET_SCHEMA.encode_example(instance)
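

# A minimal usage sketch (hypothetical values; in practice the instance is
# produced by the unitxt recipe pipeline and ``template`` is an Artifact
# resolved from the catalog, so it is left as a placeholder here):
#
#     finalized = Finalize().process(
#         {
#             "source": "Question: What is 2+2?\nAnswer:",
#             "target": "4",
#             "references": ["4"],
#             "metrics": ["metrics.accuracy"],
#             "postprocessors": ["processors.to_string"],
#             "data_classification_policy": ["public"],
#             "input_fields": {"question": "What is 2+2?"},
#             "reference_fields": {"answer": "4"},
#             "recipe_metadata": {"template": template, "num_demos": 0},
#         }
#     )
#
# The result carries a JSON ``task_data`` string, a default ``group`` of
# "unitxt", and only the keys declared in UNITXT_DATASET_SCHEMA.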