botcon committed
Commit bfaee96 · 1 Parent(s): c179832

Upload meta.py

Files changed (1)
  1. meta.py +123 -0
meta.py ADDED
@@ -0,0 +1,123 @@
import torch.nn as nn
import torch
from transformers import AutoTokenizer, BertForSequenceClassification, PreTrainedModel, PretrainedConfig, get_scheduler
from transformers.modeling_outputs import SequenceClassifierOutput
from torch.nn import CrossEntropyLoss
from torch.optim import AdamW
from LUKE_pipe import generate
from datasets import load_dataset
from accelerate import Accelerator
from tqdm import tqdm

MAX_BEAM = 10  # also used as the "no match" label, ignored by the loss
tf32 = True
torch.backends.cuda.matmul.allow_tf32 = tf32
torch.backends.cudnn.allow_tf32 = tf32
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


class ClassifierAdapter(nn.Module):
    """Re-ranks beam candidates by combining the generator score with a BERT classifier."""

    def __init__(self, l1=3):
        super().__init__()
        self.linear1 = nn.Linear(l1, 1)
        self.tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
        self.bert = BertForSequenceClassification.from_pretrained("botcon/right_span_bert")
        self.relu = nn.ReLU()

    def forward(self, questions, answers, logits):
        beam_size = len(answers[0])
        samples = len(questions)
        # Repeat each question once per beam candidate (sample-major) so the
        # (question, answer) pairs line up with the flattened answers and with
        # the reshape below.
        questions = [question for question in questions for _ in range(beam_size)]
        answers = [answer for beam in answers for answer in beam]
        inputs = self.tokenizer(
            questions,
            answers,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        ).to(device)
        bert_logits = self.bert(**inputs).logits
        bert_logits = bert_logits.reshape(samples, beam_size, 2)
        # The generator beam score becomes the third feature next to the two BERT logits.
        logits = torch.FloatTensor(logits).to(device).unsqueeze(-1)
        logits = torch.cat((logits, bert_logits), dim=-1)
        logits = self.relu(logits)
        out = torch.squeeze(self.linear1(logits), dim=-1)
        return out


class HuggingWrapper(PreTrainedModel):
    config_class = PretrainedConfig

    def __init__(self, config):
        super().__init__(config)
        self.model = ClassifierAdapter()

    def forward(self, **kwargs):
        labels = kwargs.pop("labels")
        output = self.model(**kwargs)
        # Samples whose gold answer appears in no beam candidate carry label MAX_BEAM and are ignored.
        loss_fn = CrossEntropyLoss(ignore_index=MAX_BEAM)
        loss = loss_fn(output, labels)
        return SequenceClassifierOutput(logits=output, loss=loss)


accelerator = Accelerator(mixed_precision="fp16")
model = HuggingWrapper.from_pretrained("botcon/special_bert").to(device)
optimizer = AdamW(model.parameters())
model, optimizer = accelerator.prepare(model, optimizer)

batch_size = 2
raw_datasets = load_dataset("squad")
raw_train = raw_datasets["train"]
num_updates = len(raw_train) // batch_size
num_epoch = 2
num_training_steps = num_updates * num_epoch
lr_scheduler = get_scheduler(
    "linear",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=num_training_steps,
)

progress_bar = tqdm(range(num_training_steps))

for epoch in range(num_epoch):
    start = 0
    end = batch_size
    steps = 0
    cumu_loss = 0
    training_data = raw_train
    model.train()
    while start < len(training_data):
        optimizer.zero_grad()
        batch_data = raw_train.select(range(start, min(end, len(raw_train))))
        # Generate beam candidates for the batch; no gradients flow through the generator.
        with torch.no_grad():
            res = generate(batch_data)
        prediction = []
        predicted_logit = []
        labels = []
        for i in range(len(res)):
            x = res[i]
            ground_answer = batch_data["answers"][i]["text"][0]
            predicted_text = x["prediction_text"]
            # The label is the index of the beam candidate matching the gold answer,
            # or MAX_BEAM (ignored by the loss) when none of them match.
            found = False
            for k in range(len(predicted_text)):
                if predicted_text[k] == ground_answer:
                    labels.append(k)
                    found = True
                    break
            if not found:
                labels.append(MAX_BEAM)
            prediction.append(predicted_text)
            predicted_logit.append(x["logits"])
        labels = torch.LongTensor(labels).to(device)
        classifier_out = model(questions=batch_data["question"], answers=prediction, logits=predicted_logit, labels=labels)
        loss = classifier_out.loss
        # The loss is NaN when every label in the batch is ignored; skip the update
        # in that case so NaN gradients never reach the weights.
        if not torch.isnan(loss).item():
            cumu_loss += loss.item()
            accelerator.backward(loss)
            optimizer.step()
            lr_scheduler.step()
        steps += 1
        progress_bar.update(1)
        start += batch_size
        end += batch_size
        # log the running loss every 100 steps
        if steps % 100 == 0:
            print("Cumu loss: {}".format(cumu_loss / 100))
            cumu_loss = 0

model.push_to_hub("Adapter Bert")
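
LUKE_pipe.generate is imported above but is not part of this commit. A minimal stand-in sketch of the interface meta.py appears to expect, inferred only from the call sites in the training loop (the body below is hypothetical, not the real module):

# Hypothetical stub mirroring the contract meta.py assumes of LUKE_pipe.generate.
from typing import Dict, List, Union

BEAM = 10  # candidates per question, matching MAX_BEAM in meta.py

def generate(batch_data) -> List[Dict[str, Union[List[str], List[float]]]]:
    """For each example, return beam candidates and their generator scores.

    Each element is expected to provide:
      - "prediction_text": up to BEAM candidate answer strings, best first
      - "logits": one generator score per candidate, in the same order
    """
    results = []
    for _question in batch_data["question"]:
        # Placeholder: a real implementation would run the LUKE QA pipeline here.
        results.append({
            "prediction_text": ["" for _ in range(BEAM)],
            "logits": [0.0 for _ in range(BEAM)],
        })
    return results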