unimer_demo / unimernet /tasks /unimernet_train.py
wufan's picture
Upload 111 files
18e4b60 verified
import torch
import evaluate
import random
from unimernet.common.registry import registry
from unimernet.tasks.base_task import BaseTask
from unimernet.common.dist_utils import main_process
import os.path as osp
import json
import numpy as np
from torchtext.data import metrics
from rapidfuzz.distance import Levenshtein
@registry.register_task("unimernet_train")
class UniMERNet_Train(BaseTask):
def __init__(self, temperature, do_sample, top_p, evaluate, report_metric=True, agg_metric="edit_distance"):
super(UniMERNet_Train, self).__init__()
self.temperature = temperature
self.do_sample = do_sample
self.top_p = top_p
self.evaluate = evaluate
self.agg_metric = agg_metric
self.report_metric = report_metric
@classmethod
def setup_task(cls, cfg):
run_cfg = cfg.run_cfg
generate_cfg = run_cfg.generate_cfg
temperature = generate_cfg.get('temperature', .2)
do_sample = generate_cfg.get("do_sample", False)
top_p = generate_cfg.get("top_p", 0.95)
evaluate = run_cfg.evaluate
report_metric = run_cfg.get("report_metric", True)
agg_metric = run_cfg.get("agg_metric", "edit_distance")
return cls(
temperature=temperature,
do_sample=do_sample,
top_p=top_p,
evaluate=evaluate,
report_metric=report_metric,
agg_metric=agg_metric,
)
def valid_step(self, model, samples):
results = []
image, text = samples["image"], samples["text_input"]
preds = model.generate(
samples,
temperature=self.temperature,
do_sample=self.do_sample,
top_p=self.top_p
)
pred_tokens = preds["pred_tokens"]
pred_strs = preds["pred_str"]
pred_ids = preds["pred_ids"] # [b, n-1]
truth_inputs = model.tokenizer.tokenize(text)
truth_ids = truth_inputs["input_ids"][:, 1:]
truth_tokens = model.tokenizer.detokenize(truth_inputs["input_ids"])
truth_strs = model.tokenizer.token2str(truth_inputs["input_ids"])
ids = samples["id"]
for pred_token, pred_str, pred_id, truth_token, truth_str, truth_id, id_ in zip(pred_tokens, pred_strs,
pred_ids, truth_tokens,
truth_strs, truth_ids, ids):
pred_id = pred_id.tolist()
truth_id = truth_id.tolist()
shape_diff = len(pred_id) - len(truth_id)
if shape_diff < 0:
pred_id = pred_id + [model.tokenizer.pad_token_id] * (-shape_diff)
else:
truth_id = truth_id + [model.tokenizer.pad_token_id] * shape_diff
pred_id, truth_id = torch.LongTensor(pred_id), torch.LongTensor(truth_id)
mask = torch.logical_or(pred_id != model.tokenizer.pad_token_id, truth_id != model.tokenizer.pad_token_id)
tok_acc = (pred_id == truth_id)[mask].float().mean().item()
this_item = {
"pred_token": pred_token,
"pred_str": pred_str,
"truth_str": truth_str,
"truth_token": truth_token,
"token_acc": tok_acc,
"id": id_
}
results.append(this_item)
return results
def after_evaluation(self, val_result, split_name, epoch, **kwargs):
eval_result_file = self.save_result(
result=val_result,
result_dir=registry.get_path("result_dir"),
filename="{}_epoch{}".format(split_name, epoch),
remove_duplicate="id",
)
if self.report_metric:
metrics = self._report_metrics(
eval_result_file=eval_result_file, split_name=split_name
)
else:
metrics = {"agg_metrics": 0.0}
return metrics
@main_process
def _report_metrics(self, eval_result_file, split_name):
with open(eval_result_file) as f:
results = json.load(f)
edit_dists = []
all_pred_tokens = []
all_truth_tokens = []
all_pred_strs = []
all_truth_strs = []
token_accs = []
for result in results:
pred_token, pred_str, truth_token, truth_str, tok_acc = result["pred_token"], result["pred_str"], result[
"truth_token"], result["truth_str"], result["token_acc"]
if len(truth_str) > 0:
norm_edit_dist = Levenshtein.normalized_distance(pred_str, truth_str)
edit_dists.append(norm_edit_dist)
all_pred_tokens.append(pred_token)
all_truth_tokens.append([truth_token])
all_pred_strs.append(pred_str)
all_truth_strs.append(truth_str)
token_accs.append(tok_acc)
# bleu_score = metrics.bleu_score(all_pred_tokens, all_truth_tokens)
bleu = evaluate.load("bleu", keep_in_memory=True, experiment_id=random.randint(1, 1e8))
bleu_results = bleu.compute(predictions=all_pred_strs, references=all_truth_strs)
bleu_score = bleu_results['bleu']
edit_distance = np.mean(edit_dists)
token_accuracy = np.mean(token_accs)
eval_ret = {"bleu": bleu_score, "edit_distance": edit_distance, "token_accuracy": token_accuracy}
log_stats = {split_name: {k: v for k, v in eval_ret.items()}}
with open(
osp.join(registry.get_path("output_dir"), "evaluate.txt"), "a"
) as f:
f.write(json.dumps(log_stats) + "\n")
coco_res = {k: v for k, v in eval_ret.items()}
# agg_metrics = sum([v for v in eval_ret.values()])
if "edit" in self.agg_metric.lower(): # edit_distance
agg_metrics = (1 - edit_distance) * 100
elif "bleu" in self.agg_metric.lower(): # bleu_score
agg_metrics = bleu_score * 100
elif "token" in self.agg_metric.lower(): # token_accuracy
agg_metrics = token_accuracy * 100
else:
raise ValueError(f"Invalid metrics: '{self.agg_metric}'")
coco_res["agg_metrics"] = agg_metrics
return coco_res