# Copyright (c) OpenMMLab. All rights reserved.
# Partly adopted from https://github.com/GT-Vision-Lab/VQA
# Copyright (c) 2014, Aishwarya Agrawal
import re
from vlmeval.smp import *
from typing import Optional
from functools import partial


def _process_digit_article(inText):
    outText = []
    tempText = inText.lower().split()
    articles = ['a', 'an', 'the']
    manualMap = {
        'none': '0',
        'zero': '0',
        'one': '1',
        'two': '2',
        'three': '3',
        'four': '4',
        'five': '5',
        'six': '6',
        'seven': '7',
        'eight': '8',
        'nine': '9',
        'ten': '10',
    }
    contractions = {
        'aint': "ain't",
        'arent': "aren't",
        'cant': "can't",
        'couldve': "could've",
        'couldnt': "couldn't",
        "couldn'tve": "couldn't've",
        "couldnt've": "couldn't've",
        'didnt': "didn't",
        'doesnt': "doesn't",
        'dont': "don't",
        'hadnt': "hadn't",
        "hadnt've": "hadn't've",
        "hadn'tve": "hadn't've",
        'hasnt': "hasn't",
        'havent': "haven't",
        'hed': "he'd",
        "hed've": "he'd've",
        "he'dve": "he'd've",
        'hes': "he's",
        'howd': "how'd",
        'howll': "how'll",
        'hows': "how's",
        "Id've": "I'd've",
        "I'dve": "I'd've",
        'Im': "I'm",
        'Ive': "I've",
        'isnt': "isn't",
        'itd': "it'd",
        "itd've": "it'd've",
        "it'dve": "it'd've",
        'itll': "it'll",
        'lets': "let's",
        'maam': "ma'am",
        'mightnt': "mightn't",
        "mightnt've": "mightn't've",
        "mightn'tve": "mightn't've",
        'mightve': "might've",
        'mustnt': "mustn't",
        'mustve': "must've",
        'neednt': "needn't",
        'notve': "not've",
        'oclock': "o'clock",
        'oughtnt': "oughtn't",
        "ow's'at": "'ow's'at",
        "'ows'at": "'ow's'at",
        "'ow'sat": "'ow's'at",
        'shant': "shan't",
        "shed've": "she'd've",
        "she'dve": "she'd've",
        'shes': "she's",
        'shouldve': "should've",
        'shouldnt': "shouldn't",
        "shouldnt've": "shouldn't've",
        "shouldn'tve": "shouldn't've",
        'somebodyd': "somebody'd",
        "somebodyd've": "somebody'd've",
        "somebody'dve": "somebody'd've",
        'somebodyll': "somebody'll",
        'somebodys': "somebody's",
        'someoned': "someone'd",
        "someoned've": "someone'd've",
        "someone'dve": "someone'd've",
        'someonell': "someone'll",
        'someones': "someone's",
        'somethingd': "something'd",
        "somethingd've": "something'd've",
        "something'dve": "something'd've",
        'somethingll': "something'll",
        'thats': "that's",
        'thered': "there'd",
        "thered've": "there'd've",
        "there'dve": "there'd've",
        'therere': "there're",
        'theres': "there's",
        'theyd': "they'd",
        "theyd've": "they'd've",
        "they'dve": "they'd've",
        'theyll': "they'll",
        'theyre': "they're",
        'theyve': "they've",
        'twas': "'twas",
        'wasnt': "wasn't",
        "wed've": "we'd've",
        "we'dve": "we'd've",
        'weve': "we've",
        'werent': "weren't",
        'whatll': "what'll",
        'whatre': "what're",
        'whats': "what's",
        'whatve': "what've",
        'whens': "when's",
        'whered': "where'd",
        'wheres': "where's",
        'whereve': "where've",
        'whod': "who'd",
        "whod've": "who'd've",
        "who'dve": "who'd've",
        'wholl': "who'll",
        'whos': "who's",
        'whove': "who've",
        'whyll': "why'll",
        'whyre': "why're",
        'whys': "why's",
        'wont': "won't",
        'wouldve': "would've",
        'wouldnt': "wouldn't",
        "wouldnt've": "wouldn't've",
        "wouldn'tve": "wouldn't've",
        'yall': "y'all",
        "yall'll": "y'all'll",
        "y'allll": "y'all'll",
        "yall'd've": "y'all'd've",
        "y'alld've": "y'all'd've",
        "y'all'dve": "y'all'd've",
        'youd': "you'd",
        "youd've": "you'd've",
        "you'dve": "you'd've",
        'youll': "you'll",
        'youre': "you're",
        'youve': "you've",
    }
    for word in tempText:
        # Map spelled-out digits to numerals and drop articles.
        word = manualMap.get(word, word)
        if word not in articles:
            outText.append(word)
    for wordId, word in enumerate(outText):
        # Restore apostrophes in common contractions ('dont' -> "don't").
        if word in contractions:
            outText[wordId] = contractions[word]
    outText = ' '.join(outText)
    return outText
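
# A couple of hand-checked examples of the normalization above (illustrative
# only; these strings are not from any dataset):
#
#   _process_digit_article('is there a cat and two dogs')  -> 'is there cat and 2 dogs'
#   _process_digit_article('no it doesnt')                 -> "no it doesn't"
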
def hit_calculate(result, dataset_name, anls_threshold=0.5):
    if listinstr(['TextVQA'], dataset_name):
        return [np.mean(x['match']) for x in result]
    elif listinstr(['DocVQA', 'InfoVQA'], dataset_name):
        # ANLS: score is 1 - normalized edit distance, floored to 0 below the threshold.
        # return [1 - np.min(x['match']) >= anls_threshold for x in result]
        return [
            0.0 if 1 - np.min(x['match']) < anls_threshold else 1 - np.min(x['match'])
            for x in result
        ]
    elif listinstr(['ChartQA', 'OCRVQA'], dataset_name):
        return [np.max(x['match']) for x in result]
    else:  # default using vqa_score to calculate score
        return [np.mean(x['match']) for x in result]


# https://github.com/google-research/pix2struct/blob/main/pix2struct/metrics.py#L81
def relaxed_correctness(target: str,
                        prediction: str,
                        max_relative_change: float = 0.05) -> bool:
    """Calculates relaxed correctness.

    The correctness tolerates certain error ratio defined by max_relative_change.
    See https://arxiv.org/pdf/2203.10244.pdf, end of section 5.1:
    “Following Methani et al. (2020), we use a relaxed accuracy measure for the
    numeric answers to allow a minor inaccuracy that may result from the automatic
    data extraction process. We consider an answer to be correct if it is within
    5% of the gold answer. For non-numeric answers, we still need an exact match
    to consider an answer to be correct.”

    Args:
        target: Target string.
        prediction: Predicted string.
        max_relative_change: Maximum relative change.

    Returns:
        Whether the prediction was correct given the specified tolerance.
    """

    def _to_float(text: str) -> Optional[float]:
        try:
            if text.endswith('%'):
                # Convert percentages to floats.
                return float(text.rstrip('%')) / 100.0
            else:
                return float(text)
        except ValueError:
            return None

    prediction = str(prediction)
    target = str(target)
    prediction_float = _to_float(prediction)
    target_float = _to_float(target)
    # The truthiness check on target_float also rules out target == 0,
    # which would make the relative change below ill-defined.
    if prediction_float is not None and target_float:
        relative_change = abs(prediction_float - target_float) / abs(target_float)
        return relative_change <= max_relative_change
    else:
        return prediction.lower() == target.lower()


def levenshtein_distance(s1, s2):
    # Standard dynamic-programming edit distance, keeping one row at a time.
    if len(s1) > len(s2):
        s1, s2 = s2, s1

    distances = range(len(s1) + 1)
    for i2, c2 in enumerate(s2):
        distances_ = [i2 + 1]
        for i1, c1 in enumerate(s1):
            if c1 == c2:
                distances_.append(distances[i1])
            else:
                distances_.append(1 + min((distances[i1], distances[i1 + 1], distances_[-1])))
        distances = distances_
    return distances[-1]


def anls_compute(groundtruth, prediction):
    # Normalized Levenshtein distance between whitespace-normalized,
    # lower-cased strings (0 means an exact match).
    gt_answer = ' '.join(groundtruth.strip().lower().split())
    det_answer = ' '.join(prediction.strip().lower().split())
    dist = levenshtein_distance(gt_answer, det_answer)
    length = max(len(groundtruth.upper()), len(prediction.upper()))
    values = 0.0 if length == 0 else float(dist) / float(length)
    return values


def process_answer(answer):
    answer = answer.replace('\n', ' ')
    answer = answer.replace('\t', ' ')
    answer = answer.strip()
    answer = process_punctuation(answer)
    answer = _process_digit_article(answer)
    return answer
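
# Hand-checked examples of the two matchers above (illustrative only):
#
#   relaxed_correctness('0.30', '0.315')    -> True   (|0.315 - 0.30| / 0.30 = 0.05 <= 0.05)
#   relaxed_correctness('10', '12')         -> False  (relative change 0.2 > 0.05)
#   relaxed_correctness('cat', 'Cat')       -> True   (non-numeric: case-insensitive exact match)
#   anls_compute('forty two', 'fourty two') -> 0.1    (edit distance 1 / max length 10)
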
def process_line(line, method='vqa_score'):
    ret = {}
    if istype(line['answer'], list):
        # The answer field may hold a stringified list of ground-truth answers.
        answers = eval(line['answer'])
    else:
        answers = [line['answer']]
    if method == 'vqa_score':
        ret['gt'] = [process_answer(x) for x in answers]
        ret['pred'] = process_answer(line['prediction'])
        ret['match'] = []
        for current_idx, gtAnsDatum in enumerate(ret['gt']):
            # Official VQA accuracy: a prediction scores
            # min(1, #agreeing annotators / 3) against the remaining answers.
            otherGTAns = [
                item for ret_gt_idx, item in enumerate(ret['gt'])
                if ret_gt_idx != current_idx
            ]
            matchingAns = [item for item in otherGTAns if item == ret['pred']]
            acc = min(1, float(len(matchingAns)) / 3)
            ret['match'].append(acc)
    elif method == 'anls':
        ret['gt'] = answers
        ret['pred'] = line['prediction']
        ret['match'] = [anls_compute(x, ret['pred']) for x in ret['gt']]
    elif method == 'relaxed_accuracy':
        ret['gt'] = answers
        ret['pred'] = line['prediction'].strip()
        # Note: ret['pred'] is passed as `target`, so the relative change is
        # normalized by the prediction's magnitude.
        ret['match'] = [relaxed_correctness(ret['pred'], x) for x in ret['gt']]
    elif method == 'accuracy':
        ret['gt'] = answers
        ret['pred'] = line['prediction'].strip()
        ret['match'] = [
            (1.0 if (x.strip().lower() == ret['pred'].strip().lower()) else 0.0)
            for x in ret['gt']
        ]
    else:  # default using vqa_score to calculate score
        ret['gt'] = [process_answer(x) for x in answers]
        ret['pred'] = process_answer(line['prediction'])
        ret['match'] = [x == ret['pred'] for x in ret['gt']]
    return ret


def VQAEval(eval_file, dataset_name, **kwargs):
    logger = get_logger('Evaluation')
    data = load(eval_file)
    assert 'answer' in data and 'prediction' in data
    data['prediction'] = [str(x) for x in data['prediction']]
    data['answer'] = [str(x) for x in data['answer']]
    lt = len(data)
    lines = [data.iloc[i] for i in range(lt)]
    # Pick the matching metric per dataset family; see hit_calculate for how
    # the per-line matches are aggregated.
    with mp.Pool(16) as pool:
        if listinstr(['TextVQA'], dataset_name):
            res = pool.map(partial(process_line, method='vqa_score'), lines)
        elif listinstr(['ChartQA'], dataset_name):
            res = pool.map(partial(process_line, method='relaxed_accuracy'), lines)
        elif listinstr(['OCRVQA'], dataset_name):
            res = pool.map(partial(process_line, method='accuracy'), lines)
        elif listinstr(['DocVQA', 'InfoVQA'], dataset_name):
            res = pool.map(partial(process_line, method='anls'), lines)
        else:  # default using vqa_score to calculate score
            res = pool.map(process_line, lines)
    # [np.mean(x['match']) >= full_score_weight for x in res]
    hit = hit_calculate(res, dataset_name)
    ret = dict()
    if 'split' in data:
        splits = set(data['split'])
        for sp in splits:
            sub = [r for l, r in zip(lines, res) if l['split'] == sp]
            # [np.mean(x['match']) >= full_score_weight for x in sub]
            hit = hit_calculate(sub, dataset_name)
            ret[sp] = np.mean(hit) * 100
        sub = [r for l, r in zip(lines, res)]
        hit = hit_calculate(sub, dataset_name)
        ret['Overall'] = np.mean(hit) * 100
    else:
        ret['Overall'] = np.mean(hit) * 100
        if 'category' in data:
            cates = list(set(data['category']))
            cates.sort()
            for c in cates:
                sub = [r for l, r in zip(lines, res) if l['category'] == c]
                # [np.mean(x['match']) >= full_score_weight for x in sub]
                hit = hit_calculate(sub, dataset_name)
                ret[c] = np.mean(hit) * 100
    ret = d2df(ret)
    ret = ret.round(2)

    suffix = eval_file.split('.')[-1]
    result_file = eval_file.replace(f'.{suffix}', '_acc.csv')
    dump(ret, result_file)
    logger.info(f'VQA Eval Finished. Saved to {result_file}.')
    logger.info(ret)
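

if __name__ == '__main__':
    # Minimal invocation sketch. The file name below is a placeholder, not a
    # file shipped with this repo: `eval_file` must be loadable by vlmeval's
    # `load` and carry 'answer' and 'prediction' columns (plus optional
    # 'split' / 'category'). Scores are written next to it as
    # 'example_textvqa_predictions_acc.csv'.
    VQAEval('example_textvqa_predictions.xlsx', 'TextVQA')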