import torch
import evaluate
import random

from unimernet.common.registry import registry
from unimernet.tasks.base_task import BaseTask
from unimernet.common.dist_utils import main_process
import os.path as osp
import json
import numpy as np
from torchtext.data import metrics
from rapidfuzz.distance import Levenshtein


@registry.register_task("unimernet_train")
class UniMERNet_Train(BaseTask):
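    """Training / evaluation task for UniMERNet.

    At validation time the model generates predictions for each batch, and the
    task reports BLEU, normalized edit distance and token accuracy against the
    ground-truth strings.
    """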

    def __init__(self, temperature, do_sample, top_p, evaluate, report_metric=True, agg_metric="edit_distance"):
        super(UniMERNet_Train, self).__init__()
        self.temperature = temperature
        self.do_sample = do_sample
        self.top_p = top_p
        self.evaluate = evaluate
        self.agg_metric = agg_metric

        self.report_metric = report_metric

    @classmethod
    def setup_task(cls, cfg):
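        """Build the task from the run config; generation options come from `run_cfg.generate_cfg`."""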
        run_cfg = cfg.run_cfg
        generate_cfg = run_cfg.generate_cfg

        temperature = generate_cfg.get("temperature", 0.2)
        do_sample = generate_cfg.get("do_sample", False)
        top_p = generate_cfg.get("top_p", 0.95)

        evaluate = run_cfg.evaluate
        report_metric = run_cfg.get("report_metric", True)
        agg_metric = run_cfg.get("agg_metric", "edit_distance")

        return cls(
            temperature=temperature,
            do_sample=do_sample,
            top_p=top_p,
            evaluate=evaluate,
            report_metric=report_metric,
            agg_metric=agg_metric,
        )

    def valid_step(self, model, samples):
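        """Generate predictions for one validation batch and score each sample.

        Returns a list with one dict per sample containing the predicted and
        ground-truth tokens/strings, the per-sample token accuracy and the sample id.
        """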
        results = []
        image, text = samples["image"], samples["text_input"]
        preds = model.generate(
            samples,
            temperature=self.temperature,
            do_sample=self.do_sample,
            top_p=self.top_p
        )
        pred_tokens = preds["pred_tokens"]
        pred_strs = preds["pred_str"]
        pred_ids = preds["pred_ids"]  # [b, n-1]

        # Tokenize the ground truth and drop the leading start token so the ids line up with the generated ids.
        truth_inputs = model.tokenizer.tokenize(text)
        truth_ids = truth_inputs["input_ids"][:, 1:]
        truth_tokens = model.tokenizer.detokenize(truth_inputs["input_ids"])
        truth_strs = model.tokenizer.token2str(truth_inputs["input_ids"])

        ids = samples["id"]

        for pred_token, pred_str, pred_id, truth_token, truth_str, truth_id, id_ in zip(pred_tokens, pred_strs,
                                                                                        pred_ids, truth_tokens,
                                                                                        truth_strs, truth_ids, ids):
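            # Right-pad the shorter id sequence with the pad token so the two can be compared element-wise.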
            pred_id = pred_id.tolist()
            truth_id = truth_id.tolist()
            shape_diff = len(pred_id) - len(truth_id)
            if shape_diff < 0:
                pred_id = pred_id + [model.tokenizer.pad_token_id] * (-shape_diff)
            else:
                truth_id = truth_id + [model.tokenizer.pad_token_id] * shape_diff
            pred_id, truth_id = torch.LongTensor(pred_id), torch.LongTensor(truth_id)
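            # Count only positions where at least one of the two sequences has a non-pad token.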
            mask = torch.logical_or(pred_id != model.tokenizer.pad_token_id, truth_id != model.tokenizer.pad_token_id)
            tok_acc = (pred_id == truth_id)[mask].float().mean().item()

            this_item = {
                "pred_token": pred_token,
                "pred_str": pred_str,
                "truth_str": truth_str,
                "truth_token": truth_token,
                "token_acc": tok_acc,
                "id": id_
            }
            results.append(this_item)
        return results

    def after_evaluation(self, val_result, split_name, epoch, **kwargs):
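        """Persist the validation results and, if enabled, compute metrics from the saved file."""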
        eval_result_file = self.save_result(
            result=val_result,
            result_dir=registry.get_path("result_dir"),
            filename="{}_epoch{}".format(split_name, epoch),
            remove_duplicate="id",
        )

        if self.report_metric:
            metrics = self._report_metrics(
                eval_result_file=eval_result_file, split_name=split_name
            )
        else:
            metrics = {"agg_metrics": 0.0}

        return metrics

    @main_process
    def _report_metrics(self, eval_result_file, split_name):
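        """Compute BLEU, normalized edit distance and token accuracy from the saved result file.

        The aggregate metric is chosen via `self.agg_metric` and scaled to a 0-100 range.
        """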

        with open(eval_result_file) as f:
            results = json.load(f)

        edit_dists = []
        all_pred_tokens = []
        all_truth_tokens = []
        all_pred_strs = []
        all_truth_strs = []
        token_accs = []
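        # Accumulate per-sample results; edit distance is computed only for non-empty references.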
        for result in results:
            pred_token, pred_str = result["pred_token"], result["pred_str"]
            truth_token, truth_str = result["truth_token"], result["truth_str"]
            tok_acc = result["token_acc"]

            if len(truth_str) > 0:
                norm_edit_dist = Levenshtein.normalized_distance(pred_str, truth_str)
                edit_dists.append(norm_edit_dist)

            all_pred_tokens.append(pred_token)
            all_truth_tokens.append([truth_token])
            all_pred_strs.append(pred_str)
            all_truth_strs.append(truth_str)
            token_accs.append(tok_acc)

        # bleu_score = metrics.bleu_score(all_pred_tokens, all_truth_tokens)
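        # Corpus BLEU via HuggingFace `evaluate`; a unique experiment_id keeps concurrent runs from
        # clashing on the metric's cache files.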
        bleu = evaluate.load("bleu", keep_in_memory=True, experiment_id=str(random.randint(1, int(1e8))))
        bleu_results = bleu.compute(predictions=all_pred_strs, references=all_truth_strs)
        bleu_score = bleu_results['bleu']
        
        edit_distance = np.mean(edit_dists)
        token_accuracy = np.mean(token_accs)
        eval_ret = {"bleu": bleu_score, "edit_distance": edit_distance, "token_accuracy": token_accuracy}

        log_stats = {split_name: {k: v for k, v in eval_ret.items()}}

        with open(
                osp.join(registry.get_path("output_dir"), "evaluate.txt"), "a"
        ) as f:
            f.write(json.dumps(log_stats) + "\n")

        coco_res = {k: v for k, v in eval_ret.items()}
        # agg_metrics = sum([v for v in eval_ret.values()])
        if "edit" in self.agg_metric.lower():  # edit_distance
            agg_metrics = (1 - edit_distance) * 100
        elif "bleu" in self.agg_metric.lower():  # bleu_score
            agg_metrics = bleu_score * 100
        elif "token" in self.agg_metric.lower():  # token_accuracy
            agg_metrics = token_accuracy * 100
        else:
            raise ValueError(f"Invalid metrics: '{self.agg_metric}'")

        coco_res["agg_metrics"] = agg_metrics

        return coco_res