# coding=utf-8
# Copyleft 2019 project LXRT.
import json
import torch
class AnswerTable:
    """Vocabulary of answers shared with lxmert pre-training.

    Loads the answer list from ``data/lxmert/all_ans.json`` and provides:
    - normalization of raw answer strings into canonical vocabulary form
      (``convert_ans``), and
    - bidirectional mapping between answer strings and their integer ids
      in the pre-trained QA head (``ans2id`` / ``id2ans``).
    """

    # Canonical rewrites applied after lower-casing and article/period
    # stripping: spelled-out numbers -> digits, spelling variants, and
    # article-prefixed phrases that survive the prefix stripping.
    ANS_CONVERT = {
        "a man": "man",
        "the man": "man",
        "a woman": "woman",
        "the woman": "woman",
        'one': '1',
        'two': '2',
        'three': '3',
        'four': '4',
        'five': '5',
        'six': '6',
        'seven': '7',
        'eight': '8',
        'nine': '9',
        'ten': '10',
        'grey': 'gray',
    }

    def __init__(self, dsets=None):
        """
        :param dsets: Optional iterable of dataset names. When given, only
            answers used by at least one of those datasets are kept;
            otherwise the full answer vocabulary is used.
        """
        # Use a context manager so the file handle is closed promptly
        # (the original `json.load(open(...))` leaked the handle).
        with open("data/lxmert/all_ans.json") as f:
            self.all_ans = json.load(f)
        if dsets is not None:
            dsets = set(dsets)
            # Keep only the answers used in at least one of the dsets.
            self.anss = [ans['ans'] for ans in self.all_ans
                         if len(set(ans['dsets']) & dsets) > 0]
        else:
            self.anss = [ans['ans'] for ans in self.all_ans]
        self.ans_set = set(self.anss)

        # Ids are positions in self.anss; the two maps must stay inverse.
        self._id2ans_map = self.anss
        self._ans2id_map = {ans: ans_id for ans_id, ans in enumerate(self.anss)}

        assert len(self._id2ans_map) == len(self._ans2id_map)
        for ans_id, ans in enumerate(self._id2ans_map):
            assert self._ans2id_map[ans] == ans_id

    def convert_ans(self, ans):
        """Normalize a raw answer string to its canonical vocabulary form.

        Lower-cases, drops a trailing period, strips a leading article
        ("a ", "an ", "the "), then applies the ANS_CONVERT rewrites.
        Returns "" for an empty input.
        """
        if len(ans) == 0:
            return ""
        ans = ans.lower()
        if ans[-1] == '.':
            ans = ans[:-1].strip()
        if ans.startswith("a "):
            ans = ans[2:].strip()
        if ans.startswith("an "):
            ans = ans[3:].strip()
        if ans.startswith("the "):
            ans = ans[4:].strip()
        if ans in self.ANS_CONVERT:
            ans = self.ANS_CONVERT[ans]
        return ans

    def ans2id(self, ans):
        """Return the integer id of a (canonicalized) answer string."""
        return self._ans2id_map[ans]

    def id2ans(self, ans_id):
        """Return the answer string for an integer id."""
        return self._id2ans_map[ans_id]

    def ans2id_map(self):
        """Return a copy of the answer -> id dict."""
        return self._ans2id_map.copy()

    def id2ans_map(self):
        """Return a copy of the id -> answer list."""
        return self._id2ans_map.copy()

    def used(self, ans):
        """True iff *ans* is in the (possibly dataset-filtered) vocabulary."""
        return ans in self.ans_set

    def all_answers(self):
        """Return a copy of the answer list (index == id)."""
        return self.anss.copy()

    @property
    def num_answers(self):
        """Number of answers in the vocabulary."""
        return len(self.anss)
def load_lxmert_qa(path, model, label2ans):
    """
    Load model weights from lxmert pre-training.
    The answers in the fine-tuned QA task (indicated by label2ans)
    would also be properly initialized with lxmert pre-trained
    QA heads.
    :param path: Path to lxmert snapshot.
    :param model: LXRT model instance.
    :param label2ans: The label2ans dict of fine-tuned QA datasets, like
        {0: 'cat', 1: 'dog', ...}
    :return:
    """
    print("Load QA pre-trained lxmert from %s " % path)
    pretrained = torch.load("%s_LXRT.pth" % path)
    current = model.state_dict()

    # Handle Multi-GPU pre-training --> Single GPU fine-tuning:
    # strip the DataParallel "module." prefix from every key.
    pretrained = {k.replace("module.", ''): v for k, v in pretrained.items()}

    # Isolate bert model weights (keys keep their "bert." prefix).
    bert_weights = {k: v for k, v in pretrained.items()
                    if k.startswith('bert.')}

    # Isolate answer head weights, dropping the "answer_head." prefix.
    head_weights = {k.replace('answer_head.', ''): v
                    for k, v in pretrained.items()
                    if k.startswith("answer_head.")}

    # Do surgery on the answer classifier: for each fine-tuning label,
    # copy the matching pre-trained answer row when the answer exists in
    # the pre-training vocabulary, otherwise zero it out.
    old_weight = head_weights['logit_fc.3.weight']
    old_bias = head_weights['logit_fc.3.bias']
    import copy
    fresh_weight = copy.deepcopy(current['logit_fc.3.weight'])
    fresh_bias = copy.deepcopy(current['logit_fc.3.bias'])
    table = AnswerTable()
    num_loaded = 0
    num_skipped = 0
    if type(label2ans) is list:
        label2ans = dict(enumerate(label2ans))
    for label, raw_ans in label2ans.items():
        canon = table.convert_ans(raw_ans)
        if table.used(canon):
            row = table.ans2id(canon)
            fresh_weight[label] = old_weight[row]
            fresh_bias[label] = old_bias[row]
            num_loaded += 1
        else:
            fresh_weight[label] = 0.
            fresh_bias[label] = 0.
            num_skipped += 1
    print("Loaded %d answers from LXRTQA pre-training and %d not" % (num_loaded, num_skipped))
    print()
    head_weights['logit_fc.3.weight'] = fresh_weight
    head_weights['logit_fc.3.bias'] = fresh_bias

    # Load Bert weights: every key the encoder expects must be covered
    # by the checkpoint (extras are tolerated via strict=False).
    expected_bert = set(model.lxrt_encoder.model.state_dict().keys())
    assert len(expected_bert - set(bert_weights.keys())) == 0
    model.lxrt_encoder.model.load_state_dict(bert_weights, strict=False)

    # Load Answer Logic FC weights: every checkpoint key must exist in
    # the model (the model may have additional, untouched keys).
    assert len(set(head_weights.keys()) - set(model.state_dict().keys())) == 0
    model.load_state_dict(head_weights, strict=False)
|