import pandas as pd
import hashlib
from ..smp import *
from .dataset_config import dataset_URLs, dataset_md5_dict, DATASET_TYPE
from .custom_prompt import CustomPrompt
from .matching_util import can_infer


def isliststr(s):
    """Return True if `s` is a string that starts with '[' and ends with ']'.

    Non-string values (e.g. NaN cells) and empty strings yield False instead
    of raising.
    """
    return isinstance(s, str) and s.startswith('[') and s.endswith(']')


def check_md5(data_path, dataset):
    """Check the md5 digest of `data_path` against the record for `dataset`.

    Args:
        data_path: Path of the local data file; it must exist.
        dataset: Key looked up in `dataset_md5_dict`.

    Returns:
        True if no md5 is recorded for `dataset` (a warning is emitted) or if
        the digest matches the record; False (with a warning) otherwise.
    """
    if dataset not in dataset_md5_dict:
        warnings.warn(f'We do not have an md5 record for dataset {dataset}, skip the md5 check. ')
        return True
    assert osp.exists(data_path)

    hasher = hashlib.md5()
    with open(data_path, 'rb') as f:
        # Read in 1 MiB chunks so large files are never loaded at once.
        for chunk in iter(lambda: f.read(2**20), b''):
            hasher.update(chunk)

    if hasher.hexdigest() == dataset_md5_dict[dataset]:
        return True
    warnings.warn('this data file is incomplete, so it needs to be downloaded again.')
    return False


def split_MMMU(msgs):
    """Interleave the images of `msgs` into its text at each image placeholder.

    Args:
        msgs: List of dicts with keys 'type' ('image' or 'text') and 'value'.
            At most one 'text' entry is allowed.

    Returns:
        A list of dicts of the same shape, where the text has been cut at
        every placeholder and the referenced image inserted in between.
    """
    text, images = None, []
    for s in msgs:
        if s['type'] == 'image':
            images.append(s['value'])
        elif s['type'] == 'text':
            assert text is None
            text = s['value']

    # NOTE(review): this part of the function was truncated in the file as
    # received (everything between "text.split('" and "image_idx = ..." was
    # missing). The '<image N>' placeholder format, the initial text segment
    # and the loop below are reconstructed from the surviving statements,
    # which read a one-digit 1-based index at seg[0] and resume the text at
    # seg[2:] -- confirm against the upstream implementation.
    text_segs = text.split('<image ')
    segs = [dict(type='text', value=text_segs[0])]
    for seg in text_segs[1:]:
        assert len(seg) >= 2 and seg[0].isdigit() and seg[1] == '>'
        image_idx = int(seg[0]) - 1
        segs.append(dict(type='image', value=images[image_idx]))
        segs.append(dict(type='text', value=seg[2:]))
    return segs


def MMMU_result_transfer(result_path):
    """Convert a prediction table into an `{id: answer}` JSON file.

    Rows whose 'A' column is not null are treated as multiple-choice: the
    prediction is passed through `can_infer` together with the row's
    non-null option columns. For all other rows the raw prediction is kept.

    Args:
        result_path: Path of the prediction file read with `load`.

    Returns:
        Path of the written JSON file (`result_path` with '.xlsx' replaced
        by '.json').
    """
    res = {}
    result_data = load(result_path)
    mcq = result_data['A'].notna()

    for i in range(len(result_data)):
        line = result_data.iloc[i]
        # Positional lookup, so it stays aligned with `result_data.iloc[i]`
        # even when the frame does not carry a default 0..n-1 index.
        if mcq.iloc[i]:
            options = {
                cand: line[cand]
                for cand in string.ascii_uppercase
                if cand in line and not pd.isna(line[cand])
            }
            prediction = line['prediction']
            res[line['id']] = can_infer(prediction, options)
        else:
            res[line['id']] = line['prediction']

    result_json = result_path.replace('.xlsx', '.json')
    dump(res, result_json)
    return result_json


class TSVDataset(CustomPrompt):
    """Dataset backed by a TSV file stored under the LMU data root.

    Attributes set by `__init__`:
        data_root: Directory returned by `LMUDataRoot()`.
        dataset: Dataset name.
        dataset_type: Result of `DATASET_TYPE(dataset)`.
        skip_noimg: Whether rows with a null 'image' cell were dropped.
        meta_only: True when the table has no 'image' column, in which case
            `build_prompt` reads images from the 'image_path' column.
        data: The loaded (and normalised) table.
    """

    def __init__(self, dataset='MMBench', skip_noimg=True):
        """Locate (downloading if necessary) and load the TSV for `dataset`.

        Args:
            dataset: Dataset name. If it is a key of `dataset_URLs` the file
                is resolved from that URL; otherwise `<data_root>/<dataset>.tsv`
                must already exist.
            skip_noimg: Drop rows whose 'image' cell is null.
        """
        self.data_root = LMUDataRoot()
        assert osp.exists(self.data_root)

        self.dataset = dataset
        self.dataset_type = DATASET_TYPE(dataset)

        if dataset in dataset_URLs:
            url = dataset_URLs[dataset]
            file_name = url.split('/')[-1]
            data_path = osp.join(self.data_root, file_name)

            if osp.exists(data_path) and check_md5(data_path, dataset):
                pass
            elif osp.isfile(url):
                # If url is actually a file path, use it directly
                data_path = url
            else:
                warnings.warn('The dataset tsv is not downloaded')
                download_file(url, data_path)
        else:
            data_path = osp.join(self.data_root, dataset + '.tsv')
            assert osp.exists(data_path)

        data = load(data_path)
        self.skip_noimg = skip_noimg
        if skip_noimg and 'image' in data:
            data = data[~pd.isna(data['image'])]

        # Prompt for Captioning
        if listinstr(['COCO'], dataset):
            data['question'] = [(
                'Please describe this image in general. Directly provide the description, '
                'do not include prefix like "This image depicts". '
            )] * len(data)

        data['index'] = [str(x) for x in data['index']]

        self.meta_only = True
        if 'image' in data:
            data['image'] = [str(x) for x in data['image']]

            image_map = {x: y for x, y in zip(data['index'], data['image'])}
            for k in image_map:
                # A short cell (<= 64 chars) is not image data but the index
                # of another row whose image should be reused.
                if len(image_map[k]) <= 64:
                    idx = image_map[k]
                    assert idx in image_map and len(image_map[idx]) > 64
                    image_map[k] = image_map[idx]

            # SECURITY NOTE: `eval` executes whatever the TSV cell contains.
            # Only load dataset files from trusted sources.
            data['image'] = [
                eval(image_map[k]) if isliststr(image_map[k]) else image_map[k]
                for k in data['index']
            ]
            self.meta_only = False

        if 'image_path' in data:
            # SECURITY NOTE: same `eval` caveat as for the 'image' column.
            data['image_path'] = [
                eval(pths) if isliststr(pths) else pths
                for pths in data['image_path']
            ]

        # Restore integer indices when every index value parses as an int.
        if np.all([istype(x, int) for x in data['index']]):
            data['index'] = [int(x) for x in data['index']]

        self.data = data

    def __len__(self):
        """Return the number of rows in the loaded table."""
        return len(self.data)

    def build_prompt(self, line, dataset=None):
        """Build the list of image/text messages for one sample.

        Args:
            line: A row of `self.data`, or an int position into it.
            dataset: Dataset name used to pick the prompt style; defaults to
                `self.dataset`.

        Returns:
            A list of dicts `{'type': 'image' | 'text', 'value': ...}`: one
            entry per image followed by a single text entry.
        """
        if dataset is None:
            dataset = self.dataset

        if isinstance(line, int):
            line = self.data.iloc[line]

        if self.meta_only:
            tgt_path = line['image_path']
        else:
            tgt_path = self.dump_image(line, dataset)

        prompt = line['question']
        if DATASET_TYPE(dataset) == 'multi-choice':
            question = line['question']
            options = {
                cand: line[cand]
                for cand in string.ascii_uppercase
                if cand in line and not pd.isna(line[cand])
            }
            options_prompt = 'Options:\n'
            for key, item in options.items():
                options_prompt += f'{key}. {item}\n'
            hint = line['hint'] if ('hint' in line and not pd.isna(line['hint'])) else None

            prompt = ''
            if hint is not None:
                prompt += f'Hint: {hint}\n'
            prompt += f'Question: {question}\n'
            if len(options):
                prompt += options_prompt
                prompt += 'Please select the correct answer from the options above. \n'
        elif DATASET_TYPE(dataset) == 'VQA':
            if listinstr(['ocrvqa', 'textvqa', 'chartqa', 'docvqa'], dataset.lower()):
                prompt += '\nPlease try to answer the question with short words or phrases if possible\n.'

        msgs = []
        if isinstance(tgt_path, list):
            msgs.extend([dict(type='image', value=p) for p in tgt_path])
        else:
            msgs = [dict(type='image', value=tgt_path)]
        msgs.append(dict(type='text', value=prompt))

        return msgs

    def display(self, line):
        """Show one sample via `mmqa_display`; `line` is a row or an int position."""
        if isinstance(line, int):
            line = self.data.iloc[line]
        mmqa_display(line)