# Provenance: explain-LXMERT / lxmert/src/pretrain/qa_answer_table.py
# (uploaded by WwYc, "Upload 61 files", revision 08d7644, 5.02 kB)
# coding=utf-8
# Copyleft 2019 project LXRT.
import json
import torch
class AnswerTable:
    """Vocabulary of QA answers shared across LXMERT pre-training datasets.

    Reads the global answer list from ``data/lxmert/all_ans.json`` and
    builds bidirectional answer <-> integer-id maps, optionally restricted
    to the answers used by a given collection of datasets.
    """

    # Normalization table applied after lower-casing and article stripping:
    # maps frequent surface variants onto a canonical answer string.
    ANS_CONVERT = {
        "a man": "man",
        "the man": "man",
        "a woman": "woman",
        "the woman": "woman",
        'one': '1',
        'two': '2',
        'three': '3',
        'four': '4',
        'five': '5',
        'six': '6',
        'seven': '7',
        'eight': '8',
        'nine': '9',
        'ten': '10',
        'grey': 'gray',
    }

    def __init__(self, dsets=None):
        """
        :param dsets: optional iterable of dataset names; when given, only
            answers used by at least one of these datasets are kept.
        """
        # Use a context manager so the file handle is closed promptly
        # (the original `json.load(open(...))` leaked the handle).
        with open("data/lxmert/all_ans.json") as f:
            self.all_ans = json.load(f)
        if dsets is not None:
            dsets = set(dsets)
            # Keep only answers that appear in at least one requested dataset.
            self.anss = [ans['ans'] for ans in self.all_ans if
                         len(set(ans['dsets']) & dsets) > 0]
        else:
            self.anss = [ans['ans'] for ans in self.all_ans]
        self.ans_set = set(self.anss)

        self._id2ans_map = self.anss
        self._ans2id_map = {ans: ans_id for ans_id, ans in enumerate(self.anss)}

        # Sanity check: the two maps must be exact inverses of each other.
        assert len(self._id2ans_map) == len(self._ans2id_map)
        for ans_id, ans in enumerate(self._id2ans_map):
            assert self._ans2id_map[ans] == ans_id

    def convert_ans(self, ans):
        """Normalize a raw answer string.

        Lower-cases, strips a trailing period and a leading article
        ("a"/"an"/"the"), then maps known variants via ``ANS_CONVERT``.
        Returns "" for an empty input.
        """
        if len(ans) == 0:
            return ""
        ans = ans.lower()
        if ans[-1] == '.':
            ans = ans[:-1].strip()
        # The article checks are sequential independent ifs, so a string
        # like "a the man" is stripped twice (preserved original behavior).
        if ans.startswith("a "):
            ans = ans[2:].strip()
        if ans.startswith("an "):
            ans = ans[3:].strip()
        if ans.startswith("the "):
            ans = ans[4:].strip()
        if ans in self.ANS_CONVERT:
            ans = self.ANS_CONVERT[ans]
        return ans

    def ans2id(self, ans):
        """Return the integer id of a (normalized) answer; KeyError if unknown."""
        return self._ans2id_map[ans]

    def id2ans(self, ans_id):
        """Return the answer string for an integer id."""
        return self._id2ans_map[ans_id]

    def ans2id_map(self):
        """Return a shallow copy of the answer -> id dict."""
        return self._ans2id_map.copy()

    def id2ans_map(self):
        """Return a shallow copy of the id -> answer list."""
        return self._id2ans_map.copy()

    def used(self, ans):
        """True if *ans* is in the (possibly dataset-filtered) vocabulary."""
        return ans in self.ans_set

    def all_answers(self):
        """Return a copy of the full answer list."""
        return self.anss.copy()

    @property
    def num_answers(self):
        """Number of answers in the vocabulary."""
        return len(self.anss)
def load_lxmert_qa(path, model, label2ans):
    """
    Load model weights from lxmert pre-training.

    The answers in the fine-tuned QA task (indicated by label2ans)
    would also be properly initialized with lxmert pre-trained
    QA heads.

    :param path: Path prefix to the lxmert snapshot; "%s_LXRT.pth" % path
        is the file actually loaded.
    :param model: LXRT model instance (must expose ``lxrt_encoder.model``
        and a ``logit_fc`` answer head); updated in place.
    :param label2ans: The label2ans dict of fine-tuned QA datasets, like
        {0: 'cat', 1: 'dog', ...}; a plain list [ans0, ans1, ...] is also
        accepted and converted internally.
    :return: None
    """
    print("Load QA pre-trained lxmert from %s " % path)
    # map_location='cpu' lets a GPU-saved checkpoint load on CPU-only hosts;
    # load_state_dict below copies tensors onto the model's own device.
    loaded_state_dict = torch.load("%s_LXRT.pth" % path, map_location='cpu')
    model_state_dict = model.state_dict()

    # Handle Multi-GPU pre-training --> Single GPU fine-tuning:
    # drop the DataParallel "module." wrapper from every key.
    for key in list(loaded_state_dict.keys()):
        loaded_state_dict[key.replace("module.", '')] = loaded_state_dict.pop(key)

    # Isolate bert model weights (keys keep their "bert." prefix).
    bert_state_dict = {key: value
                       for key, value in loaded_state_dict.items()
                       if key.startswith('bert.')}

    # Isolate answer-head weights, stripping the "answer_head." prefix.
    answer_state_dict = {key.replace('answer_head.', ''): value
                         for key, value in loaded_state_dict.items()
                         if key.startswith("answer_head.")}

    # Do surgery on the answer head: remap the pre-training answer rows of
    # the final logit layer onto the fine-tuning label space.
    ans_weight = answer_state_dict['logit_fc.3.weight']
    ans_bias = answer_state_dict['logit_fc.3.bias']
    # .clone() replaces the original copy.deepcopy: identical effect for
    # tensors, without the mid-function `import copy`.
    new_answer_weight = model_state_dict['logit_fc.3.weight'].clone()
    new_answer_bias = model_state_dict['logit_fc.3.bias'].clone()
    answer_table = AnswerTable()
    loaded = 0
    unload = 0
    if type(label2ans) is list:
        label2ans = {label: ans for label, ans in enumerate(label2ans)}
    for label, ans in label2ans.items():
        new_ans = answer_table.convert_ans(ans)
        if answer_table.used(new_ans):
            # Answer seen in pre-training: reuse its pre-trained row.
            ans_id_9500 = answer_table.ans2id(new_ans)
            new_answer_weight[label] = ans_weight[ans_id_9500]
            new_answer_bias[label] = ans_bias[ans_id_9500]
            loaded += 1
        else:
            # Unseen answer: zero-initialize its logit row.
            new_answer_weight[label] = 0.
            new_answer_bias[label] = 0.
            unload += 1
    print("Loaded %d answers from LXRTQA pre-training and %d not" % (loaded, unload))
    print()
    answer_state_dict['logit_fc.3.weight'] = new_answer_weight
    answer_state_dict['logit_fc.3.bias'] = new_answer_bias

    # Load Bert Weights: every encoder key must exist in the checkpoint.
    bert_model_keys = set(model.lxrt_encoder.model.state_dict().keys())
    bert_loaded_keys = set(bert_state_dict.keys())
    assert len(bert_model_keys - bert_loaded_keys) == 0
    model.lxrt_encoder.model.load_state_dict(bert_state_dict, strict=False)

    # Load Answer Logic FC Weights: no loaded key may be absent from the model.
    model_keys = set(model.state_dict().keys())
    ans_loaded_keys = set(answer_state_dict.keys())
    assert len(ans_loaded_keys - model_keys) == 0
    model.load_state_dict(answer_state_dict, strict=False)