# coding=utf-8
# Copyleft 2019 project LXRT.

import copy
import json

import torch


class AnswerTable:
    """Table of candidate answers shared across the LXMERT pre-training QA datasets."""

    # Map common paraphrases to a single canonical answer string.
    ANS_CONVERT = {
        "a man": "man",
        "the man": "man",
        "a woman": "woman",
        "the woman": "woman",
        "one": "1",
        "two": "2",
        "three": "3",
        "four": "4",
        "five": "5",
        "six": "6",
        "seven": "7",
        "eight": "8",
        "nine": "9",
        "ten": "10",
        "grey": "gray",
    }

    def __init__(self, dsets=None):
        with open("data/lxmert/all_ans.json") as f:
            self.all_ans = json.load(f)
        if dsets is not None:
            dsets = set(dsets)
            # Keep only answers that appear in at least one of the requested datasets.
            self.anss = [ans['ans'] for ans in self.all_ans
                         if len(set(ans['dsets']) & dsets) > 0]
        else:
            self.anss = [ans['ans'] for ans in self.all_ans]
        self.ans_set = set(self.anss)

        self._id2ans_map = self.anss
        self._ans2id_map = {ans: ans_id for ans_id, ans in enumerate(self.anss)}

        # Sanity check: id -> answer and answer -> id must be exact inverses.
        assert len(self._id2ans_map) == len(self._ans2id_map)
        for ans_id, ans in enumerate(self._id2ans_map):
            assert self._ans2id_map[ans] == ans_id

    def convert_ans(self, ans):
        """Normalize an answer: lowercase, strip a trailing '.', drop a leading
        article (a/an/the), then map known paraphrases via ANS_CONVERT."""
        if len(ans) == 0:
            return ""
        ans = ans.lower()
        if ans[-1] == '.':
            ans = ans[:-1].strip()
        if ans.startswith("a "):
            ans = ans[2:].strip()
        if ans.startswith("an "):
            ans = ans[3:].strip()
        if ans.startswith("the "):
            ans = ans[4:].strip()
        if ans in self.ANS_CONVERT:
            ans = self.ANS_CONVERT[ans]
        return ans

    def ans2id(self, ans):
        return self._ans2id_map[ans]

    def id2ans(self, ans_id):
        return self._id2ans_map[ans_id]

    def ans2id_map(self):
        return self._ans2id_map.copy()

    def id2ans_map(self):
        return self._id2ans_map.copy()

    def used(self, ans):
        return ans in self.ans_set

    def all_answers(self):
        return self.anss.copy()

    def num_answers(self):
        return len(self.anss)
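
# A minimal usage sketch of AnswerTable, kept as comments. It assumes
# data/lxmert/all_ans.json is present and that "vqa" is one of the dataset
# tags stored in its 'dsets' fields -- both are assumptions, not verified here:
#
#     table = AnswerTable(dsets=["vqa"])
#     table.convert_ans("The man.")  # -> "man"  (article and trailing '.' stripped)
#     table.convert_ans("Grey")      # -> "gray" (ANS_CONVERT paraphrase map)
#     if table.used("man"):
#         label = table.ans2id("man")  # id usable as a classification label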


def load_lxmert_qa(path, model, label2ans):
    """
    Load model weights from lxmert pre-training.
    The answers in the fine-tuned QA task (given by label2ans) are also
    initialized from the lxmert pre-trained QA head.
    :param path: Path to the lxmert snapshot, without the "_LXRT.pth" suffix.
    :param model: LXRT model instance.
    :param label2ans: The label2ans dict of the fine-tuned QA dataset, e.g.
        {0: 'cat', 1: 'dog', ...}
    """
    print("Load QA pre-trained lxmert from %s" % path)
    loaded_state_dict = torch.load("%s_LXRT.pth" % path)
    model_state_dict = model.state_dict()

    # Handle multi-GPU pre-training --> single-GPU fine-tuning:
    # strip the "module." prefix that DataParallel adds to parameter names.
    for key in list(loaded_state_dict.keys()):
        loaded_state_dict[key.replace("module.", '')] = loaded_state_dict.pop(key)

    # Isolate the BERT backbone weights.
    bert_state_dict = {}
    for key, value in loaded_state_dict.items():
        if key.startswith('bert.'):
            bert_state_dict[key] = value

    # Isolate the answer-head weights, dropping the "answer_head." prefix.
    answer_state_dict = {}
    for key, value in loaded_state_dict.items():
        if key.startswith("answer_head."):
            answer_state_dict[key.replace('answer_head.', '')] = value

    # Do surgery on the answer head: re-index the pre-trained output layer so
    # that each fine-tuning label gets the weights of its matching answer.
    ans_weight = answer_state_dict['logit_fc.3.weight']
    ans_bias = answer_state_dict['logit_fc.3.bias']
    new_answer_weight = copy.deepcopy(model_state_dict['logit_fc.3.weight'])
    new_answer_bias = copy.deepcopy(model_state_dict['logit_fc.3.bias'])
    answer_table = AnswerTable()
    loaded = 0
    unloaded = 0
    if type(label2ans) is list:
        label2ans = {label: ans for label, ans in enumerate(label2ans)}
    for label, ans in label2ans.items():
        new_ans = answer_table.convert_ans(ans)
        if answer_table.used(new_ans):
            # The answer exists in the pre-training vocabulary: copy its row.
            ans_id_9500 = answer_table.ans2id(new_ans)
            new_answer_weight[label] = ans_weight[ans_id_9500]
            new_answer_bias[label] = ans_bias[ans_id_9500]
            loaded += 1
        else:
            # Unseen answer: zero-initialize its row.
            new_answer_weight[label] = 0.
            new_answer_bias[label] = 0.
            unloaded += 1
    print("Loaded %d answers from LXRT QA pre-training; %d answers were not found "
          "and are zero-initialized." % (loaded, unloaded))
    print()
    answer_state_dict['logit_fc.3.weight'] = new_answer_weight
    answer_state_dict['logit_fc.3.bias'] = new_answer_bias

    # Load BERT weights: every key in the encoder must be covered by the snapshot.
    bert_model_keys = set(model.lxrt_encoder.model.state_dict().keys())
    bert_loaded_keys = set(bert_state_dict.keys())
    assert len(bert_model_keys - bert_loaded_keys) == 0
    model.lxrt_encoder.model.load_state_dict(bert_state_dict, strict=False)

    # Load answer logit FC weights: every loaded key must exist in the model.
    model_keys = set(model.state_dict().keys())
    ans_loaded_keys = set(answer_state_dict.keys())
    assert len(ans_loaded_keys - model_keys) == 0
    model.load_state_dict(answer_state_dict, strict=False)
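
# A minimal usage sketch of load_lxmert_qa, kept as comments. VQAModel and the
# snapshot path "snap/pretrained/model" are illustrative assumptions -- any LXRT
# model with a logit_fc answer head and a "<path>_LXRT.pth" file would do:
#
#     model = VQAModel(num_answers=len(label2ans))
#     load_lxmert_qa("snap/pretrained/model", model, label2ans=label2ans)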