botcon commited on
Commit
4a8a37c
1 Parent(s): 58958df

Delete LukeQuestionAnswering.py

Browse files
Files changed (1) hide show
  1. LukeQuestionAnswering.py +0 -431
LukeQuestionAnswering.py DELETED
@@ -1,431 +0,0 @@
1
- from transformers import LukePreTrainedModel, LukeModel, AutoTokenizer, TrainingArguments, default_data_collator, Trainer, LukeForQuestionAnswering
2
- from transformers.modeling_outputs import ModelOutput
3
- from typing import Optional, Tuple, Union
4
-
5
- import numpy as np
6
- from tqdm import tqdm
7
- import evaluate
8
- import torch
9
- from dataclasses import dataclass
10
- from datasets import load_dataset
11
- from torch import nn
12
- from torch.nn import CrossEntropyLoss
13
- import collections
14
-
15
- PEFT = False
16
- repo_name = "LUKE_squad_finetuned_qa"
17
- tf32 = True
18
- fp16= True
19
- train = False
20
- test = True
21
- trained_model = "LUKE_squad_finetuned_qa_tf32"
22
-
23
- torch.backends.cuda.matmul.allow_tf32 = tf32
24
- torch.backends.cudnn.allow_tf32 = tf32
25
- device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
26
-
27
- if tf32:
28
- repo_name += "_tf32"
29
-
30
- # https://github.com/huggingface/transformers/blob/v4.34.1/src/transformers/models/luke/modeling_luke.py#L319-L353
31
- # Taken from HF repository, easier to include additional features -- Currently identical to LukeForQuestionAnswering by HF
32
-
33
- @dataclass
34
- class LukeQuestionAnsweringModelOutput(ModelOutput):
35
- """
36
- Outputs of question answering models.
37
-
38
-
39
- Args:
40
- loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
41
- Total span extraction loss is the sum of a Cross-Entropy for the start and end positions.
42
- start_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
43
- Span-start scores (before SoftMax).
44
- end_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
45
- Span-end scores (before SoftMax).
46
- hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
47
- Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
48
- one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
49
-
50
-
51
- Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
52
- entity_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
53
- Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
54
- shape `(batch_size, entity_length, hidden_size)`. Entity hidden-states of the model at the output of each
55
- layer plus the initial entity embedding outputs.
56
- attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
57
- Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
58
- sequence_length)`.
59
-
60
-
61
- Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
62
- heads.
63
- """
64
-
65
-
66
- loss: Optional[torch.FloatTensor] = None
67
- start_logits: torch.FloatTensor = None
68
- end_logits: torch.FloatTensor = None
69
- hidden_states: Optional[Tuple[torch.FloatTensor]] = None
70
- entity_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
71
- attentions: Optional[Tuple[torch.FloatTensor]] = None
72
-
73
- class AugmentedLukeForQuestionAnswering(LukePreTrainedModel):
74
- def __init__(self, config):
75
- super().__init__(config)
76
-
77
- # This is 2.
78
- self.num_labels = config.num_labels
79
-
80
- self.luke = LukeModel(config, add_pooling_layer=False)
81
-
82
- '''
83
- Any improvement to the model are expected here. Additional features, anything...
84
- '''
85
- self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
86
-
87
-
88
- # Initialize weights and apply final processing
89
- self.post_init()
90
-
91
- def forward(
92
- self,
93
- input_ids: Optional[torch.LongTensor] = None,
94
- attention_mask: Optional[torch.FloatTensor] = None,
95
- token_type_ids: Optional[torch.LongTensor] = None,
96
- position_ids: Optional[torch.FloatTensor] = None,
97
- entity_ids: Optional[torch.LongTensor] = None,
98
- entity_attention_mask: Optional[torch.FloatTensor] = None,
99
- entity_token_type_ids: Optional[torch.LongTensor] = None,
100
- entity_position_ids: Optional[torch.LongTensor] = None,
101
- head_mask: Optional[torch.FloatTensor] = None,
102
- inputs_embeds: Optional[torch.FloatTensor] = None,
103
- start_positions: Optional[torch.LongTensor] = None,
104
- end_positions: Optional[torch.LongTensor] = None,
105
- output_attentions: Optional[bool] = None,
106
- output_hidden_states: Optional[bool] = None,
107
- return_dict: Optional[bool] = None,
108
- ) -> Union[Tuple, LukeQuestionAnsweringModelOutput]:
109
-
110
- r"""
111
- start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
112
- Labels for position (index) of the start of the labelled span for computing the token classification loss.
113
- Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
114
- are not taken into account for computing the loss.
115
- end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
116
- Labels for position (index) of the end of the labelled span for computing the token classification loss.
117
- Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
118
- are not taken into account for computing the loss.
119
- """
120
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
121
-
122
-
123
- outputs = self.luke(
124
- input_ids=input_ids,
125
- attention_mask=attention_mask,
126
- token_type_ids=token_type_ids,
127
- position_ids=position_ids,
128
- entity_ids=entity_ids,
129
- entity_attention_mask=entity_attention_mask,
130
- entity_token_type_ids=entity_token_type_ids,
131
- entity_position_ids=entity_position_ids,
132
- head_mask=head_mask,
133
- inputs_embeds=inputs_embeds,
134
- output_attentions=output_attentions,
135
- output_hidden_states=output_hidden_states,
136
- return_dict=True,
137
- )
138
-
139
-
140
- sequence_output = outputs.last_hidden_state
141
-
142
-
143
- logits = self.qa_outputs(sequence_output)
144
- start_logits, end_logits = logits.split(1, dim=-1)
145
- start_logits = start_logits.squeeze(-1)
146
- end_logits = end_logits.squeeze(-1)
147
-
148
-
149
- total_loss = None
150
- if start_positions is not None and end_positions is not None:
151
- # If we are on multi-GPU, split add a dimension
152
- if len(start_positions.size()) > 1:
153
- start_positions = start_positions.squeeze(-1)
154
- if len(end_positions.size()) > 1:
155
- end_positions = end_positions.squeeze(-1)
156
- # sometimes the start/end positions are outside our model inputs, we ignore these terms
157
- ignored_index = start_logits.size(1)
158
- start_positions.clamp_(0, ignored_index)
159
- end_positions.clamp_(0, ignored_index)
160
-
161
- loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
162
- start_loss = loss_fct(start_logits, start_positions)
163
- end_loss = loss_fct(end_logits, end_positions)
164
- total_loss = (start_loss + end_loss) / 2
165
-
166
-
167
- if not return_dict:
168
- return tuple(
169
- v
170
- for v in [
171
- total_loss,
172
- start_logits,
173
- end_logits,
174
- outputs.hidden_states,
175
- outputs.entity_hidden_states,
176
- outputs.attentions,
177
- ]
178
- if v is not None
179
- )
180
-
181
-
182
- return LukeQuestionAnsweringModelOutput(
183
- loss=total_loss,
184
- start_logits=start_logits,
185
- end_logits=end_logits,
186
- hidden_states=outputs.hidden_states,
187
- entity_hidden_states=outputs.entity_hidden_states,
188
- attentions=outputs.attentions,
189
- )
190
-
191
- if __name__ == "__main__":
192
- # Setting up tokenizer and helper functions
193
- # Work-around for FastTokenizer - RoBERTa and LUKE share the same subword vocab, and we are not using entities functions of LUKE-tokenizer anyways
194
- tokenizer = AutoTokenizer.from_pretrained("roberta-base")
195
-
196
- # Necessary initialization
197
- max_length = 384
198
- stride = 128
199
- batch_size = 8
200
- n_best = 20
201
- max_answer_length = 30
202
- metric = evaluate.load("squad")
203
- raw_datasets = load_dataset("squad")
204
-
205
- def compute_metrics(start_logits, end_logits, features, examples):
206
- example_to_features = collections.defaultdict(list)
207
- for idx, feature in enumerate(features):
208
- example_to_features[feature["example_id"]].append(idx)
209
-
210
- predicted_answers = []
211
- for example in tqdm(examples):
212
- example_id = example["id"]
213
- context = example["context"]
214
- answers = []
215
-
216
- # Loop through all features associated with that example
217
- for feature_index in example_to_features[example_id]:
218
- start_logit = start_logits[feature_index]
219
- end_logit = end_logits[feature_index]
220
- offsets = features[feature_index]["offset_mapping"]
221
-
222
- start_indexes = np.argsort(start_logit)[-1 : -n_best - 1 : -1].tolist()
223
- end_indexes = np.argsort(end_logit)[-1 : -n_best - 1 : -1].tolist()
224
- for start_index in start_indexes:
225
- for end_index in end_indexes:
226
- # Skip answers that are not fully in the context
227
- if offsets[start_index] is None or offsets[end_index] is None:
228
- continue
229
- # Skip answers with a length that is either < 0 or > max_answer_length
230
- if (
231
- end_index < start_index
232
- or end_index - start_index + 1 > max_answer_length
233
- ):
234
- continue
235
-
236
- answer = {
237
- "text": context[offsets[start_index][0] : offsets[end_index][1]],
238
- "logit_score": start_logit[start_index] + end_logit[end_index],
239
- }
240
- answers.append(answer)
241
-
242
- # Select the answer with the best score
243
- if len(answers) > 0:
244
- best_answer = max(answers, key=lambda x: x["logit_score"])
245
- predicted_answers.append(
246
- {"id": example_id, "prediction_text": best_answer["text"]}
247
- )
248
- else:
249
- predicted_answers.append({"id": example_id, "prediction_text": ""})
250
-
251
- theoretical_answers = [{"id": ex["id"], "answers": ex["answers"]} for ex in examples]
252
- return metric.compute(predictions=predicted_answers, references=theoretical_answers)
253
-
254
- def preprocess_training_examples(examples):
255
-
256
- questions = [q.strip() for q in examples["question"]]
257
- inputs = tokenizer(
258
- questions,
259
- examples["context"],
260
- max_length=max_length,
261
- truncation="only_second",
262
- stride=stride,
263
- return_overflowing_tokens=True,
264
- return_offsets_mapping=True,
265
- padding="max_length",
266
- )
267
-
268
- offset_mapping = inputs.pop("offset_mapping")
269
- sample_map = inputs.pop("overflow_to_sample_mapping")
270
- answers = examples["answers"]
271
- start_positions = []
272
- end_positions = []
273
-
274
- for i, offset in enumerate(offset_mapping):
275
- sample_idx = sample_map[i]
276
- answer = answers[sample_idx]
277
- start_char = answer["answer_start"][0]
278
- end_char = answer["answer_start"][0] + len(answer["text"][0])
279
- sequence_ids = inputs.sequence_ids(i)
280
-
281
- # Find the start and end of the context
282
- idx = 0
283
- while sequence_ids[idx] != 1:
284
- idx += 1
285
- context_start = idx
286
- while sequence_ids[idx] == 1:
287
- idx += 1
288
- context_end = idx - 1
289
-
290
- # If the answer is not fully inside the context, label is (0, 0)
291
- if offset[context_start][0] > start_char or offset[context_end][1] < end_char:
292
- start_positions.append(0)
293
- end_positions.append(0)
294
- else:
295
- # Otherwise it's the start and end token positions
296
- idx = context_start
297
- while idx <= context_end and offset[idx][0] <= start_char:
298
- idx += 1
299
- start_positions.append(idx - 1)
300
-
301
- idx = context_end
302
- while idx >= context_start and offset[idx][1] >= end_char:
303
- idx -= 1
304
- end_positions.append(idx + 1)
305
-
306
- inputs["start_positions"] = start_positions
307
- inputs["end_positions"] = end_positions
308
- return inputs
309
-
310
- def preprocess_validation_examples(examples):
311
- questions = [q.strip() for q in examples["question"]]
312
- inputs = tokenizer(
313
- questions,
314
- examples["context"],
315
- max_length=max_length,
316
- truncation="only_second",
317
- stride=stride,
318
- return_overflowing_tokens=True,
319
- return_offsets_mapping=True,
320
- padding="max_length",
321
- )
322
-
323
-
324
- sample_map = inputs.pop("overflow_to_sample_mapping")
325
- example_ids = []
326
-
327
- for i in range(len(inputs["input_ids"])):
328
- sample_idx = sample_map[i]
329
- example_ids.append(examples["id"][sample_idx])
330
-
331
- sequence_ids = inputs.sequence_ids(i)
332
- offset = inputs["offset_mapping"][i]
333
- inputs["offset_mapping"][i] = [
334
- o if sequence_ids[k] == 1 else None for k, o in enumerate(offset)
335
- ]
336
-
337
- inputs["example_id"] = example_ids
338
- return inputs
339
-
340
- if train:
341
- base_luke = "studio-ousia/luke-base"
342
-
343
- # tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
344
- model = AugmentedLukeForQuestionAnswering.from_pretrained(base_luke).to(device)
345
-
346
- train_dataset = raw_datasets["train"].map(
347
- preprocess_training_examples,
348
- batched=True,
349
- remove_columns=raw_datasets["train"].column_names,
350
- )
351
-
352
- validation_dataset = raw_datasets["validation"].map(
353
- preprocess_validation_examples,
354
- batched=True,
355
- remove_columns=raw_datasets["validation"].column_names,
356
- )
357
-
358
- # --------------- PEFT -------------------- # One epoch without PEFT took about 2h on my computer with CUDA - performance of PEFT kinda ass though
359
- if PEFT:
360
- from peft import get_peft_config, get_peft_model, LoraConfig, TaskType
361
-
362
- # ---- For all linear layers ----
363
- import re
364
- pattern = r'\((\w+)\): Linear'
365
- linear_layers = re.findall(pattern, str(model.modules))
366
- target_modules = list(set(linear_layers))
367
-
368
- # If using peft, can consider increaisng r for better performance
369
- peft_config = LoraConfig(
370
- task_type=TaskType.QUESTION_ANS, inference_mode=False, r=8, lora_alpha=32, lora_dropout=0.1, target_modules=target_modules, bias='all'
371
- )
372
-
373
- model = get_peft_model(model, peft_config)
374
- model.print_trainable_parameters()
375
-
376
- repo_name += "_PEFT"
377
-
378
- # ------------------------------------------ #
379
-
380
- args = TrainingArguments(
381
- repo_name,
382
- evaluation_strategy = "no",
383
- save_strategy="epoch",
384
- learning_rate=2e-5,
385
- per_device_train_batch_size=batch_size,
386
- per_device_eval_batch_size=batch_size,
387
- num_train_epochs=3,
388
- weight_decay=0.01,
389
- push_to_hub=True,
390
- fp16=fp16
391
- )
392
-
393
- trainer = Trainer(
394
- model,
395
- args,
396
- train_dataset=train_dataset,
397
- eval_dataset=validation_dataset,
398
- data_collator=default_data_collator,
399
- tokenizer=tokenizer
400
- )
401
-
402
- trainer.train()
403
-
404
- elif test:
405
- model = AugmentedLukeForQuestionAnswering.from_pretrained(trained_model).to(device)
406
-
407
- interval = len(raw_datasets["validation"]) // 100
408
- exact_match = 0
409
- f1 = 0
410
-
411
- with torch.no_grad():
412
- for i in range(1, 101):
413
- start = interval * (i - 1)
414
- end = interval * i
415
- small_eval_set = raw_datasets["validation"].select(range(start ,end))
416
- eval_set = small_eval_set.map(
417
- preprocess_validation_examples,
418
- batched=True,
419
- remove_columns=raw_datasets["validation"].column_names
420
- )
421
- eval_set_for_model = eval_set.remove_columns(["example_id", "offset_mapping"])
422
- eval_set_for_model.set_format("torch")
423
- batch = {k: eval_set_for_model[k].to(device) for k in eval_set_for_model.column_names}
424
- outputs = model(**batch)
425
- start_logits = outputs.start_logits.cpu().numpy()
426
- end_logits = outputs.end_logits.cpu().numpy()
427
- res = compute_metrics(start_logits, end_logits, eval_set, small_eval_set)
428
- exact_match += res['exact_match']
429
- f1 += res["f1"]
430
- print("F1 score: {}".format(f1 / 100))
431
- print("Exact match: {}".format(exact_match / 100))