# Dataset loading and prompt-construction utilities for TSV-based benchmarks.
import pandas as pd | |
import hashlib | |
from ..smp import * | |
from .dataset_config import dataset_URLs, dataset_md5_dict, DATASET_TYPE | |
from .custom_prompt import CustomPrompt | |
from .matching_util import can_infer | |
def isliststr(s):
    """Return True if string *s* looks like a serialized Python list.

    The length guard makes the check safe for empty strings, which would
    raise IndexError under a bare ``s[0]`` / ``s[-1]`` access.
    """
    return len(s) >= 2 and s[0] == '[' and s[-1] == ']'
def check_md5(data_path, dataset):
    """Verify the md5 checksum of a downloaded dataset file.

    Args:
        data_path: Path to the local data file; must already exist.
        dataset: Dataset name used to look up the expected md5 in
            ``dataset_md5_dict``.

    Returns:
        True when there is no md5 record for *dataset* (check skipped with a
        warning) or the file's checksum matches the record; False when the
        checksum does not match (file incomplete/corrupt).
    """
    if dataset not in dataset_md5_dict:
        warnings.warn(f'We do not have an md5 record for dataset {dataset}, skip the md5 check. ')
        return True
    assert osp.exists(data_path)
    # Use a distinct name for the digest object instead of shadowing the
    # builtin `hash`; hashlib.md5() is the direct constructor.
    digest = hashlib.md5()
    with open(data_path, 'rb') as f:
        # Read in 1 MiB chunks to keep memory bounded on large files.
        for chunk in iter(lambda: f.read(2 ** 20), b''):
            digest.update(chunk)
    # hexdigest() already returns str, no str() wrapper needed.
    if digest.hexdigest() == dataset_md5_dict[dataset]:
        return True
    warnings.warn('this data file is incomplete, so it needs to be downloaded again.')
    return False
def split_MMMU(msgs):
    """Interleave the text of an MMMU message list with its image entries.

    Args:
        msgs: List of dicts with keys ``type`` ('image' or 'text') and
            ``value``. Exactly one 'text' entry is expected; its value contains
            placeholders of the form ``<image k>`` with 1-based index *k*.

    Returns:
        A flat list of ``dict(type=..., value=...)`` segments where each
        placeholder has been replaced by the referenced image, preserving the
        surrounding text in reading order.

    Generalized over the original: the placeholder index may have multiple
    digits (the original only handled '<image 1>' .. '<image 9>').
    """
    text, images = None, []
    for msg in msgs:
        if msg['type'] == 'image':
            images.append(msg['value'])
        elif msg['type'] == 'text':
            assert text is None  # only a single text message is supported
            text = msg['value']
    text_segs = text.split('<image ')
    segs = [dict(type='text', value=text_segs[0])]
    for piece in text_segs[1:]:
        # Each placeholder must be '<image k>' with k a positive integer.
        idx_str, sep, rest = piece.partition('>')
        assert sep == '>' and idx_str.isdigit()
        segs.append(dict(type='image', value=images[int(idx_str) - 1]))
        segs.append(dict(type='text', value=rest))
    return segs
def MMMU_result_transfer(result_path):
    """Convert an MMMU .xlsx result file into an ``{id: answer}`` .json file.

    Multiple-choice rows (those with a non-null 'A' column) are passed
    through ``can_infer`` to extract a letter choice; open-ended rows keep
    their raw prediction. Returns the path of the written json file.
    """
    result_data = load(result_path)
    is_mcq = result_data['A'].notna()
    answers = {}
    for pos in range(len(result_data)):
        record = result_data.iloc[pos]
        prediction = record['prediction']
        if is_mcq[pos]:
            # Gather whichever option columns ('A'..'Z') are populated.
            choices = {}
            for letter in string.ascii_uppercase:
                if letter in record and not pd.isna(record[letter]):
                    choices[letter] = record[letter]
            prediction = can_infer(prediction, choices)
        answers[record['id']] = prediction
    result_json = result_path.replace('.xlsx', '.json')
    dump(answers, result_json)
    return result_json
class TSVDataset(CustomPrompt):
    """Dataset wrapper around a benchmark TSV file.

    Downloads the TSV when a URL is registered in ``dataset_URLs`` (with md5
    verification of any cached copy), loads it into a DataFrame, and
    normalizes the ``image`` / ``image_path`` / ``index`` columns for later
    prompt building via :meth:`build_prompt`.
    """

    def __init__(self, dataset='MMBench', skip_noimg=True):
        # Root directory that stores downloaded dataset TSVs.
        self.data_root = LMUDataRoot()
        assert osp.exists(self.data_root)
        self.dataset = dataset
        self.dataset_type = DATASET_TYPE(dataset)
        if dataset in dataset_URLs:
            url = dataset_URLs[dataset]
            file_name = url.split('/')[-1]
            data_path = osp.join(self.data_root, file_name)
            if osp.exists(data_path) and check_md5(data_path, dataset):
                pass  # cached copy exists and its checksum matches
            elif osp.isfile(url):
                # If url is actually a file path, use it directly
                data_path = url
            else:
                warnings.warn('The dataset tsv is not downloaded')
                download_file(url, data_path)
        else:
            # No registered URL: expect a local `<dataset>.tsv` under data_root.
            data_path = osp.join(self.data_root, dataset + '.tsv')
        assert osp.exists(data_path)
        data = load(data_path)
        self.skip_noimg = skip_noimg
        if skip_noimg and 'image' in data:
            # Drop rows that carry no image payload.
            data = data[~pd.isna(data['image'])]
        # Prompt for Captioning
        if listinstr(['COCO'], dataset):
            data['question'] = [(
                'Please describe this image in general. Directly provide the description, '
                'do not include prefix like "This image depicts". '
            )] * len(data)
        # Indices are normalized to str for the image_map lookups below.
        data['index'] = [str(x) for x in data['index']]
        # meta_only == True means samples only carry `image_path` references;
        # flipped to False below when inline image data is present.
        self.meta_only = True
        if 'image' in data:
            data['image'] = [str(x) for x in data['image']]
            image_map = {x: y for x, y in zip(data['index'], data['image'])}
            for k in image_map:
                # Short values (<= 64 chars) appear to be cross-references to
                # another row's index whose entry holds the actual image data
                # — TODO confirm this convention against the TSV format.
                if len(image_map[k]) <= 64:
                    idx = image_map[k]
                    assert idx in image_map and len(image_map[idx]) > 64
                    image_map[k] = image_map[idx]
            # NOTE(review): `eval` on file contents is unsafe for untrusted
            # TSVs; ast.literal_eval would be the safe equivalent — verify
            # no entry relies on non-literal expressions before switching.
            data['image'] = [
                eval(image_map[k]) if isliststr(image_map[k]) else image_map[k]
                for k in data['index']
            ]
            self.meta_only = False
        if 'image_path' in data:
            # Paths may be serialized lists; same eval caveat as above.
            data['image_path'] = [
                eval(pths) if isliststr(pths) else pths for pths in data['image_path']
            ]
        # Restore integer indices when every index string is numeric.
        if np.all([istype(x, int) for x in data['index']]):
            data['index'] = [int(x) for x in data['index']]
        self.data = data

    def __len__(self):
        # Number of samples (rows) in the loaded DataFrame.
        return len(self.data)

    def build_prompt(self, line, dataset=None):
        """Build the list of ``{'type', 'value'}`` messages for one sample.

        Args:
            line: Either an int row position into ``self.data`` or a pandas row.
            dataset: Optional dataset name override; defaults to ``self.dataset``.

        Returns:
            Image message(s) followed by a single text message. Multi-choice
            datasets get hint/question/options formatting; some VQA datasets
            get a short-answer instruction appended.
        """
        if dataset is None:
            dataset = self.dataset
        if isinstance(line, int):
            line = self.data.iloc[line]
        if self.meta_only:
            tgt_path = line['image_path']
        else:
            # Materialize inline image data to disk and use the resulting path(s).
            tgt_path = self.dump_image(line, dataset)
        prompt = line['question']
        if DATASET_TYPE(dataset) == 'multi-choice':
            question = line['question']
            # Collect whichever option columns ('A'..'Z') are populated.
            options = {
                cand: line[cand]
                for cand in string.ascii_uppercase
                if cand in line and not pd.isna(line[cand])
            }
            options_prompt = 'Options:\n'
            for key, item in options.items():
                options_prompt += f'{key}. {item}\n'
            hint = line['hint'] if ('hint' in line and not pd.isna(line['hint'])) else None
            prompt = ''
            if hint is not None:
                prompt += f'Hint: {hint}\n'
            prompt += f'Question: {question}\n'
            if len(options):
                prompt += options_prompt
                prompt += 'Please select the correct answer from the options above. \n'
        elif DATASET_TYPE(dataset) == 'VQA':
            if listinstr(['ocrvqa', 'textvqa', 'chartqa', 'docvqa'], dataset.lower()):
                prompt += '\nPlease try to answer the question with short words or phrases if possible\n.'
        msgs = []
        if isinstance(tgt_path, list):
            msgs.extend([dict(type='image', value=p) for p in tgt_path])
        else:
            msgs = [dict(type='image', value=tgt_path)]
        msgs.append(dict(type='text', value=prompt))
        return msgs

    def display(self, line):
        """Visualize one sample (int row position or pandas row) via `mmqa_display`."""
        if isinstance(line, int):
            line = self.data.iloc[line]
        mmqa_display(line)