import logging

from datasets.load import load_dataset, load_metric
from transformers import AutoTokenizer, EvalPrediction, default_data_collator

from tasks.qa.utils_qa import postprocess_qa_predictions
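
# Data pipeline for extractive QA on SQuAD / SQuAD v2: tokenizes examples into
# fixed-length features with a sliding window, labels answer spans for training,
# and post-processes start/end logits back into text answers for the squad metrics.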
class SQuAD:
    def __init__(self, tokenizer: AutoTokenizer, data_args, training_args, qa_args) -> None:
        self.data_args = data_args
        self.training_args = training_args
        self.qa_args = qa_args
        self.version_2 = data_args.dataset_name == "squad_v2"

        raw_datasets = load_dataset(data_args.dataset_name)
        column_names = raw_datasets['train'].column_names

        self.question_column_name = "question"
        self.context_column_name = "context"
        self.answer_column_name = "answers"

        self.tokenizer = tokenizer
        # BERT-style tokenizers pad on the right, so the question goes first
        # and the context second in each input pair.
        self.pad_on_right = tokenizer.padding_side == "right"
        self.max_seq_len = 384  # hardcoded; data_args.max_seq_length is ignored here

        if training_args.do_train:
            self.train_dataset = raw_datasets['train']
            self.train_dataset = self.train_dataset.map(
                self.prepare_train_dataset,
                batched=True,
                remove_columns=column_names,
                load_from_cache_file=True,
                desc="Running tokenizer on train dataset",
            )
            if data_args.max_train_samples is not None:
                self.train_dataset = self.train_dataset.select(range(data_args.max_train_samples))

        if training_args.do_eval:
            self.eval_examples = raw_datasets['validation']
            if data_args.max_eval_samples is not None:
                self.eval_examples = self.eval_examples.select(range(data_args.max_eval_samples))
            self.eval_dataset = self.eval_examples.map(
                self.prepare_eval_dataset,
                batched=True,
                remove_columns=column_names,
                load_from_cache_file=True,
                desc="Running tokenizer on validation dataset",
            )
            # One example can yield several features, so cap the feature count as well.
            if data_args.max_eval_samples is not None:
                self.eval_dataset = self.eval_dataset.select(range(data_args.max_eval_samples))

        self.predict_dataset = None
        self.data_collator = default_data_collator
        self.metric = load_metric(data_args.dataset_name)
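
    # Training features: one example can produce several overlapping windows;
    # each window gets start/end token positions, or the CLS index when the
    # answer is absent (SQuAD v2) or falls outside the window.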
    def prepare_train_dataset(self, examples):
        # Some questions have leading whitespace, which is not meaningful; strip it.
        examples['question'] = [q.lstrip() for q in examples['question']]
        # Tokenize with a sliding window: long contexts overflow into several
        # features, each overlapping its neighbor by `stride` tokens.
        tokenized = self.tokenizer(
            examples['question' if self.pad_on_right else 'context'],
            examples['context' if self.pad_on_right else 'question'],
            truncation='only_second' if self.pad_on_right else 'only_first',
            max_length=self.max_seq_len,
            stride=128,
            return_overflowing_tokens=True,
            return_offsets_mapping=True,
            padding="max_length",
        )
        # Map each overflowed feature back to the example it came from.
        sample_mapping = tokenized.pop("overflow_to_sample_mapping")
        # Character spans of each token in the original text, used to locate the answer.
        offset_mapping = tokenized.pop("offset_mapping")

        tokenized["start_positions"] = []
        tokenized["end_positions"] = []
        for i, offsets in enumerate(offset_mapping):
            input_ids = tokenized['input_ids'][i]
            # Unanswerable features are labeled with the CLS index.
            cls_index = input_ids.index(self.tokenizer.cls_token_id)
            sequence_ids = tokenized.sequence_ids(i)
            sample_index = sample_mapping[i]
            answers = examples['answers'][sample_index]
            if len(answers['answer_start']) == 0:
                tokenized["start_positions"].append(cls_index)
                tokenized["end_positions"].append(cls_index)
            else:
                # Character start/end of the first gold answer.
                start_char = answers["answer_start"][0]
                end_char = start_char + len(answers["text"][0])
                # Move to the first and last tokens of the context in this feature.
                token_start_index = 0
                while sequence_ids[token_start_index] != (1 if self.pad_on_right else 0):
                    token_start_index += 1
                token_end_index = len(input_ids) - 1
                while sequence_ids[token_end_index] != (1 if self.pad_on_right else 0):
                    token_end_index -= 1
                # Detect if the answer is out of the span
                # (in which case this feature is labeled with the CLS index).
                if not (offsets[token_start_index][0] <= start_char and offsets[token_end_index][1] >= end_char):
                    tokenized["start_positions"].append(cls_index)
                    tokenized["end_positions"].append(cls_index)
                else:
                    # Otherwise move token_start_index and token_end_index to the two ends of the answer.
                    # Note: we could go after the last offset if the answer is the last word (edge case).
                    while token_start_index < len(offsets) and offsets[token_start_index][0] <= start_char:
                        token_start_index += 1
                    tokenized["start_positions"].append(token_start_index - 1)
                    while offsets[token_end_index][1] >= end_char:
                        token_end_index -= 1
                    tokenized["end_positions"].append(token_end_index + 1)
        return tokenized
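
    # Evaluation features keep `example_id` and a masked `offset_mapping` so that
    # postprocess_qa_predictions can map each window's logits back to a character
    # span in the original context.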
    def prepare_eval_dataset(self, examples):
        examples['question'] = [q.lstrip() for q in examples['question']]
        tokenized = self.tokenizer(
            examples['question' if self.pad_on_right else 'context'],
            examples['context' if self.pad_on_right else 'question'],
            truncation='only_second' if self.pad_on_right else 'only_first',
            max_length=self.max_seq_len,
            stride=128,
            return_overflowing_tokens=True,
            return_offsets_mapping=True,
            padding="max_length",
        )
        sample_mapping = tokenized.pop("overflow_to_sample_mapping")

        tokenized["example_id"] = []
        for i in range(len(tokenized["input_ids"])):
            # Grab the sequence ids of this feature (to know what is context and what is question).
            sequence_ids = tokenized.sequence_ids(i)
            context_index = 1 if self.pad_on_right else 0
            # One example can give several spans; this is the index of the example containing this span.
            sample_index = sample_mapping[i]
            tokenized["example_id"].append(examples["id"][sample_index])
            # Set to None the offset_mapping entries that are not part of the context, so it is
            # easy to tell whether a token position belongs to the context or not.
            tokenized["offset_mapping"][i] = [
                (o if sequence_ids[k] == context_index else None)
                for k, o in enumerate(tokenized["offset_mapping"][i])
            ]
        return tokenized

    def compute_metrics(self, p: EvalPrediction):
        return self.metric.compute(predictions=p.predictions, references=p.label_ids)
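
    # Convert raw start/end logits into text answers and wrap them, together with
    # the gold answers, in the (predictions, references) format the metric expects.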
    def post_processing_function(self, examples, features, predictions, stage='eval'):
        predictions = postprocess_qa_predictions(
            examples=examples,
            features=features,
            predictions=predictions,
            version_2_with_negative=self.version_2,
            n_best_size=self.qa_args.n_best_size,
            max_answer_length=self.qa_args.max_answer_length,
            null_score_diff_threshold=self.qa_args.null_score_diff_threshold,
            output_dir=self.training_args.output_dir,
            prefix=stage,
            log_level=logging.INFO,
        )
        if self.version_2:
            # The squad_v2 metric also expects a no-answer probability per prediction.
            formatted_predictions = [
                {"id": k, "prediction_text": v, "no_answer_probability": 0.0} for k, v in predictions.items()
            ]
        else:
            formatted_predictions = [{"id": k, "prediction_text": v} for k, v in predictions.items()]
        references = [{"id": ex["id"], "answers": ex['answers']} for ex in examples]
        return EvalPrediction(predictions=formatted_predictions, label_ids=references)
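
# Minimal usage sketch (hypothetical: the real training script builds the
# data_args / training_args / qa_args objects; SimpleNamespace below only stands
# in for the attributes this class actually reads, and the qa_args values mirror
# common defaults rather than anything defined in this file):
#
#   from types import SimpleNamespace
#   tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
#   data_args = SimpleNamespace(dataset_name="squad", max_train_samples=None,
#                               max_eval_samples=None)
#   training_args = SimpleNamespace(do_train=True, do_eval=True, output_dir="out")
#   qa_args = SimpleNamespace(n_best_size=20, max_answer_length=30,
#                             null_score_diff_threshold=0.0)
#   squad = SQuAD(tokenizer, data_args, training_args, qa_args)
#   print(squad.train_dataset[0]["start_positions"])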