import torch
import evaluate
import random
from unimernet.common.registry import registry
from unimernet.tasks.base_task import BaseTask
from unimernet.common.dist_utils import main_process
import os.path as osp
import json
import numpy as np
from torchtext.data import metrics
from rapidfuzz.distance import Levenshtein


@registry.register_task("unimernet_train")
class UniMERNet_Train(BaseTask):
    def __init__(self, temperature, do_sample, top_p, evaluate, report_metric=True, agg_metric="edit_distance"):
        super(UniMERNet_Train, self).__init__()
        self.temperature = temperature
        self.do_sample = do_sample
        self.top_p = top_p
        self.evaluate = evaluate
        self.agg_metric = agg_metric
        self.report_metric = report_metric

    @classmethod
    def setup_task(cls, cfg):
        run_cfg = cfg.run_cfg
        generate_cfg = run_cfg.generate_cfg

        temperature = generate_cfg.get("temperature", 0.2)
        do_sample = generate_cfg.get("do_sample", False)
        top_p = generate_cfg.get("top_p", 0.95)

        evaluate = run_cfg.evaluate
        report_metric = run_cfg.get("report_metric", True)
        agg_metric = run_cfg.get("agg_metric", "edit_distance")

        return cls(
            temperature=temperature,
            do_sample=do_sample,
            top_p=top_p,
            evaluate=evaluate,
            report_metric=report_metric,
            agg_metric=agg_metric,
        )
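
    # Sketch of the run config fields read by setup_task above. The YAML layout and key
    # nesting are assumed here for illustration, not taken from this file:
    #
    #   run:
    #     evaluate: True
    #     report_metric: True
    #     agg_metric: edit_distance      # or "bleu" / "token_accuracy"
    #     generate_cfg:
    #       temperature: 0.2
    #       do_sample: False
    #       top_p: 0.95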

    def valid_step(self, model, samples):
        results = []

        image, text = samples["image"], samples["text_input"]
        preds = model.generate(
            samples,
            temperature=self.temperature,
            do_sample=self.do_sample,
            top_p=self.top_p,
        )
        pred_tokens = preds["pred_tokens"]
        pred_strs = preds["pred_str"]
        pred_ids = preds["pred_ids"]  # [b, n-1]

        truth_inputs = model.tokenizer.tokenize(text)
        truth_ids = truth_inputs["input_ids"][:, 1:]
        truth_tokens = model.tokenizer.detokenize(truth_inputs["input_ids"])
        truth_strs = model.tokenizer.token2str(truth_inputs["input_ids"])

        ids = samples["id"]

        for pred_token, pred_str, pred_id, truth_token, truth_str, truth_id, id_ in zip(
                pred_tokens, pred_strs, pred_ids, truth_tokens, truth_strs, truth_ids, ids):
            pred_id = pred_id.tolist()
            truth_id = truth_id.tolist()

            # Pad the shorter id sequence with pad tokens so both have the same length, then
            # score token accuracy only over positions where at least one side is non-pad.
            shape_diff = len(pred_id) - len(truth_id)
            if shape_diff < 0:
                pred_id = pred_id + [model.tokenizer.pad_token_id] * (-shape_diff)
            else:
                truth_id = truth_id + [model.tokenizer.pad_token_id] * shape_diff
            pred_id, truth_id = torch.LongTensor(pred_id), torch.LongTensor(truth_id)
            mask = torch.logical_or(pred_id != model.tokenizer.pad_token_id,
                                    truth_id != model.tokenizer.pad_token_id)
            tok_acc = (pred_id == truth_id)[mask].float().mean().item()

            this_item = {
                "pred_token": pred_token,
                "pred_str": pred_str,
                "truth_str": truth_str,
                "truth_token": truth_token,
                "token_acc": tok_acc,
                "id": id_,
            }
            results.append(this_item)
        return results
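
    # Illustrative (made-up) numbers for the padding/masking in valid_step above, writing
    # pad_token_id as 0 for readability:
    #   pred_id = [5, 7, 9], truth_id = [5, 7]  ->  truth_id padded to [5, 7, 0]
    #   mask = [True, True, True]   (last position kept because pred_id is non-pad there)
    #   token_acc = mean([5 == 5, 7 == 7, 9 == 0]) = 2/3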

    def after_evaluation(self, val_result, split_name, epoch, **kwargs):
        eval_result_file = self.save_result(
            result=val_result,
            result_dir=registry.get_path("result_dir"),
            filename="{}_epoch{}".format(split_name, epoch),
            remove_duplicate="id",
        )

        if self.report_metric:
            metrics = self._report_metrics(
                eval_result_file=eval_result_file, split_name=split_name
            )
        else:
            metrics = {"agg_metrics": 0.0}

        return metrics

    @main_process
    def _report_metrics(self, eval_result_file, split_name):
        with open(eval_result_file) as f:
            results = json.load(f)

        edit_dists = []
        all_pred_tokens = []
        all_truth_tokens = []
        all_pred_strs = []
        all_truth_strs = []
        token_accs = []
        for result in results:
            pred_token, pred_str = result["pred_token"], result["pred_str"]
            truth_token, truth_str = result["truth_token"], result["truth_str"]
            tok_acc = result["token_acc"]
            if len(truth_str) > 0:
                norm_edit_dist = Levenshtein.normalized_distance(pred_str, truth_str)
                edit_dists.append(norm_edit_dist)
            all_pred_tokens.append(pred_token)
            all_truth_tokens.append([truth_token])
            all_pred_strs.append(pred_str)
            all_truth_strs.append(truth_str)
            token_accs.append(tok_acc)

        # bleu_score = metrics.bleu_score(all_pred_tokens, all_truth_tokens)
        # A random experiment_id avoids cache-file collisions between concurrent evaluation runs.
        bleu = evaluate.load("bleu", keep_in_memory=True, experiment_id=random.randint(1, int(1e8)))
        bleu_results = bleu.compute(predictions=all_pred_strs, references=all_truth_strs)
        bleu_score = bleu_results["bleu"]

        edit_distance = np.mean(edit_dists)
        token_accuracy = np.mean(token_accs)

        eval_ret = {"bleu": bleu_score, "edit_distance": edit_distance, "token_accuracy": token_accuracy}
        log_stats = {split_name: {k: v for k, v in eval_ret.items()}}

        with open(
            osp.join(registry.get_path("output_dir"), "evaluate.txt"), "a"
        ) as f:
            f.write(json.dumps(log_stats) + "\n")

        coco_res = {k: v for k, v in eval_ret.items()}

        # agg_metrics = sum([v for v in eval_ret.values()])
        # The aggregate metric is oriented so that higher is always better: edit distance is
        # inverted, BLEU and token accuracy are scaled to [0, 100].
        if "edit" in self.agg_metric.lower():  # edit_distance
            agg_metrics = (1 - edit_distance) * 100
        elif "bleu" in self.agg_metric.lower():  # bleu_score
            agg_metrics = bleu_score * 100
        elif "token" in self.agg_metric.lower():  # token_accuracy
            agg_metrics = token_accuracy * 100
        else:
            raise ValueError(f"Invalid agg_metric: '{self.agg_metric}'")
        coco_res["agg_metrics"] = agg_metrics

        return coco_res
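
# Minimal usage sketch (assumed, not part of this file): in LAVIS-style pipelines the task
# registered above is resolved by name and built from the run config, roughly:
#
#   task_cls = registry.get_task_class("unimernet_train")
#   task = task_cls.setup_task(cfg)   # cfg.run_cfg must provide generate_cfg, evaluate, ...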