botcon committed
Commit 7b28a10 (1 parent: 1d59f45)

Upload LukeQuestionAnswering.py

Files changed (1)
  1. LukeQuestionAnswering.py +349 -0
LukeQuestionAnswering.py ADDED
@@ -0,0 +1,349 @@
from transformers import LukePreTrainedModel, LukeModel, AutoTokenizer, TrainingArguments, default_data_collator, Trainer, LukeForQuestionAnswering
from transformers.modeling_outputs import ModelOutput
from typing import Optional, Tuple, Union

import torch
from dataclasses import dataclass
from datasets import load_dataset
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

PEFT = False
repo_name = "LUKE_squad_finetuned_qa"
tf32 = True
fp16 = True

torch.backends.cuda.matmul.allow_tf32 = tf32
torch.backends.cudnn.allow_tf32 = tf32

if tf32:
    repo_name += "_tf32"

# https://github.com/huggingface/transformers/blob/v4.34.1/src/transformers/models/luke/modeling_luke.py#L319-L353
# Taken from the HF repository so extra features are easier to add -- currently identical to HF's LukeForQuestionAnswering

@dataclass
class LukeQuestionAnsweringModelOutput(ModelOutput):
    """
    Outputs of question answering models.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Total span extraction loss is the sum of a Cross-Entropy for the start and end positions.
        start_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
            Span-start scores (before SoftMax).
        end_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
            Span-end scores (before SoftMax).
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        entity_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, entity_length, hidden_size)`. Entity hidden-states of the model at the output of each
            layer plus the initial entity embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    loss: Optional[torch.FloatTensor] = None
    start_logits: torch.FloatTensor = None
    end_logits: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    entity_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None

class AugmentedLukeForQuestionAnswering(LukePreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        # num_labels is 2 (start and end position of the answer span).
        self.num_labels = config.num_labels

        self.luke = LukeModel(config, add_pooling_layer=False)

        '''
        Any improvements to the model are expected to go here: additional features, anything...
        '''
        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
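
        # Illustrative sketch only (not part of the original model): one possible improvement
        # would be a deeper span-prediction head in place of the single linear layer above, e.g.
        # self.qa_outputs = nn.Sequential(
        #     nn.Linear(config.hidden_size, config.hidden_size),
        #     nn.GELU(),
        #     nn.Linear(config.hidden_size, config.num_labels),
        # )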

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.FloatTensor] = None,
        entity_ids: Optional[torch.LongTensor] = None,
        entity_attention_mask: Optional[torch.FloatTensor] = None,
        entity_token_type_ids: Optional[torch.LongTensor] = None,
        entity_position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        start_positions: Optional[torch.LongTensor] = None,
        end_positions: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, LukeQuestionAnsweringModelOutput]:
        r"""
        start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
            are not taken into account for computing the loss.
        end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
            are not taken into account for computing the loss.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.luke(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            entity_ids=entity_ids,
            entity_attention_mask=entity_attention_mask,
            entity_token_type_ids=entity_token_type_ids,
            entity_position_ids=entity_position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )

        sequence_output = outputs.last_hidden_state

        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1)
        end_logits = end_logits.squeeze(-1)

        total_loss = None
        if start_positions is not None and end_positions is not None:
            # If we are on multi-GPU, the gathered positions may carry an extra dimension; squeeze it
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)
            # Sometimes the start/end positions are outside our model inputs; we ignore these terms
            ignored_index = start_logits.size(1)
            start_positions.clamp_(0, ignored_index)
            end_positions.clamp_(0, ignored_index)

            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

        if not return_dict:
            return tuple(
                v
                for v in [
                    total_loss,
                    start_logits,
                    end_logits,
                    outputs.hidden_states,
                    outputs.entity_hidden_states,
                    outputs.attentions,
                ]
                if v is not None
            )

        return LukeQuestionAnsweringModelOutput(
            loss=total_loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=outputs.hidden_states,
            entity_hidden_states=outputs.entity_hidden_states,
            attentions=outputs.attentions,
        )

if __name__ == "__main__":
    base_luke = "studio-ousia/luke-base"

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    # LUKE does not ship a fast tokenizer.
    # Work-around: RoBERTa and LUKE share the same subword vocab, and the entity-specific
    # functions of the LUKE tokenizer are not used here anyway, so the fast RoBERTa tokenizer works.
    tokenizer = AutoTokenizer.from_pretrained("roberta-base")

    # tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
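
    # Optional sanity check for the work-around above (illustrative, not in the original file):
    # both tokenizers should split ordinary text into the same subword ids.
    # luke_tokenizer = AutoTokenizer.from_pretrained(base_luke)  # slow LUKE tokenizer
    # sample = "Who wrote the Declaration of Independence?"
    # assert luke_tokenizer(sample)["input_ids"] == tokenizer(sample)["input_ids"]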
194
+ model = AugmentedLukeForQuestionAnswering.from_pretrained(base_luke).to(device)
195
+
196
+ raw_datasets = load_dataset("squad")
197
+
198
+ # not exactly hyperparameters
199
+ max_length = 384
200
+ stride = 128
201
+ batch_size = 8
202
+
    def preprocess_training_examples(examples):
        questions = [q.strip() for q in examples["question"]]
        inputs = tokenizer(
            questions,
            examples["context"],
            max_length=max_length,
            truncation="only_second",
            stride=stride,
            return_overflowing_tokens=True,
            return_offsets_mapping=True,
            padding="max_length",
        )

        offset_mapping = inputs.pop("offset_mapping")
        sample_map = inputs.pop("overflow_to_sample_mapping")
        answers = examples["answers"]
        start_positions = []
        end_positions = []

        for i, offset in enumerate(offset_mapping):
            sample_idx = sample_map[i]
            answer = answers[sample_idx]
            start_char = answer["answer_start"][0]
            end_char = answer["answer_start"][0] + len(answer["text"][0])
            sequence_ids = inputs.sequence_ids(i)

            # Find the start and end of the context
            idx = 0
            while sequence_ids[idx] != 1:
                idx += 1
            context_start = idx
            while sequence_ids[idx] == 1:
                idx += 1
            context_end = idx - 1

            # If the answer is not fully inside the context, label is (0, 0)
            if offset[context_start][0] > start_char or offset[context_end][1] < end_char:
                start_positions.append(0)
                end_positions.append(0)
            else:
                # Otherwise it's the start and end token positions
                idx = context_start
                while idx <= context_end and offset[idx][0] <= start_char:
                    idx += 1
                start_positions.append(idx - 1)

                idx = context_end
                while idx >= context_start and offset[idx][1] >= end_char:
                    idx -= 1
                end_positions.append(idx + 1)

        inputs["start_positions"] = start_positions
        inputs["end_positions"] = end_positions
        return inputs

    train_dataset = raw_datasets["train"].map(
        preprocess_training_examples,
        batched=True,
        remove_columns=raw_datasets["train"].column_names,
    )

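    # Illustrative check (not in the original script): long contexts are split into overlapping
    # windows by return_overflowing_tokens, so the processed set has more rows than the raw split
    # and contains only model inputs plus the start/end labels.
    print(len(raw_datasets["train"]), "raw examples ->", len(train_dataset), "training features")
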
    def preprocess_validation_examples(examples):
        questions = [q.strip() for q in examples["question"]]
        inputs = tokenizer(
            questions,
            examples["context"],
            max_length=max_length,
            truncation="only_second",
            stride=stride,
            return_overflowing_tokens=True,
            return_offsets_mapping=True,
            padding="max_length",
        )

        sample_map = inputs.pop("overflow_to_sample_mapping")
        example_ids = []

        for i in range(len(inputs["input_ids"])):
            sample_idx = sample_map[i]
            example_ids.append(examples["id"][sample_idx])

            sequence_ids = inputs.sequence_ids(i)
            offset = inputs["offset_mapping"][i]
            inputs["offset_mapping"][i] = [
                o if sequence_ids[k] == 1 else None for k, o in enumerate(offset)
            ]

        inputs["example_id"] = example_ids
        return inputs

    validation_dataset = raw_datasets["validation"].map(
        preprocess_validation_examples,
        batched=True,
        remove_columns=raw_datasets["validation"].column_names,
    )

    # --------------- PEFT -------------------- #
    # One epoch without PEFT took about 2h on my machine with CUDA; PEFT performance was noticeably worse, though.
    if PEFT:
        from peft import get_peft_config, get_peft_model, LoraConfig, TaskType

        # ---- Target all linear layers ----
        import re
        pattern = r'\((\w+)\): Linear'
        linear_layers = re.findall(pattern, str(model))  # scrape submodule names from the model repr
        target_modules = list(set(linear_layers))

        # If using PEFT, consider increasing r for better performance
        peft_config = LoraConfig(
            task_type=TaskType.QUESTION_ANS, inference_mode=False, r=8, lora_alpha=32, lora_dropout=0.1, target_modules=target_modules, bias='all'
        )

        model = get_peft_model(model, peft_config)
        model.print_trainable_parameters()

        repo_name += "_PEFT"

    # ------------------------------------------ #

    args = TrainingArguments(
        repo_name,
        evaluation_strategy="no",
        save_strategy="epoch",
        learning_rate=2e-5,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        num_train_epochs=3,
        weight_decay=0.01,
        push_to_hub=True,
        fp16=fp16,
    )

    trainer = Trainer(
        model,
        args,
        train_dataset=train_dataset,
        eval_dataset=validation_dataset,
        data_collator=default_data_collator,
        tokenizer=tokenizer,
    )

    trainer.train()

    # Not complete yet: answer post-processing is still missing; currently using the HF Hub to get results.
    # https://huggingface.co/learn/nlp-course/chapter7/7?fw=pt
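
    # ------------------------------------------ #
    # Sketch of the missing post-processing / evaluation step, adapted from the course chapter
    # linked above; it is not part of the original upload. It assumes the `evaluate` and `numpy`
    # packages are installed and uses the course defaults for n_best and max_answer_length.
    import collections

    import evaluate
    import numpy as np

    metric = evaluate.load("squad")
    n_best = 20
    max_answer_length = 30

    def compute_metrics(start_logits, end_logits, features, examples):
        example_to_features = collections.defaultdict(list)
        for idx, feature in enumerate(features):
            example_to_features[feature["example_id"]].append(idx)

        predicted_answers = []
        for example in examples:
            example_id = example["id"]
            context = example["context"]
            answers = []

            # Loop over all features associated with this example
            for feature_index in example_to_features[example_id]:
                start_logit = start_logits[feature_index]
                end_logit = end_logits[feature_index]
                offsets = features[feature_index]["offset_mapping"]

                start_indexes = np.argsort(start_logit)[-1 : -n_best - 1 : -1].tolist()
                end_indexes = np.argsort(end_logit)[-1 : -n_best - 1 : -1].tolist()
                for start_index in start_indexes:
                    for end_index in end_indexes:
                        # Skip answers that are not fully inside the context
                        if offsets[start_index] is None or offsets[end_index] is None:
                            continue
                        # Skip answers with a negative length or one longer than max_answer_length
                        if end_index < start_index or end_index - start_index + 1 > max_answer_length:
                            continue
                        answers.append(
                            {
                                "text": context[offsets[start_index][0] : offsets[end_index][1]],
                                "logit_score": start_logit[start_index] + end_logit[end_index],
                            }
                        )

            # Select the answer with the best score
            if len(answers) > 0:
                best_answer = max(answers, key=lambda x: x["logit_score"])
                predicted_answers.append({"id": example_id, "prediction_text": best_answer["text"]})
            else:
                predicted_answers.append({"id": example_id, "prediction_text": ""})

        theoretical_answers = [{"id": ex["id"], "answers": ex["answers"]} for ex in examples]
        return metric.compute(predictions=predicted_answers, references=theoretical_answers)

    predictions, _, _ = trainer.predict(validation_dataset)
    start_logits, end_logits = predictions
    print(compute_metrics(start_logits, end_logits, validation_dataset, raw_datasets["validation"]))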