import os
import json

import torch
from torch.utils.data import Dataset
from torch.nn.utils.rnn import pad_sequence
from transformers import AutoTokenizer

from .utils import get_class_to_index


class NERDataset(Dataset):
    def __init__(self, args, data_file, split='train'):
        super().__init__()
        self.args = args
        if data_file:
            data_path = os.path.join(args.data_path, data_file)
            with open(data_path) as f:
                self.data = json.load(f)
            self.name = os.path.basename(data_file).split('.')[0]
        self.split = split
        self.is_train = (split == 'train')
        # Previously: BertTokenizerFast.from_pretrained('allenai/scibert_scivocab_uncased')
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.args.roberta_checkpoint, cache_dir=self.args.cache_dir)
        self.class_to_index = get_class_to_index(self.args.corpus)
        self.index_to_class = {index: key for key, index in self.class_to_index.items()}

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        example = self.data[str(idx)]
        text_tokenized = self.tokenizer(
            example['text'], truncation=True, max_length=self.args.max_seq_length)
        # Debugging aid: flag sequences that exceed the encoder's 512-token limit.
        if len(text_tokenized['input_ids']) > 512:
            print(len(text_tokenized['input_ids']))
        # The untruncated tokenization is also returned so labels can be
        # aligned for entities that fall beyond the truncation boundary.
        text_tokenized_untruncated = self.tokenizer(example['text'])
        return (text_tokenized,
                self.align_labels(text_tokenized, example['entities'], len(example['text'])),
                self.align_labels(text_tokenized_untruncated, example['entities'], len(example['text'])))

    def align_labels(self, text_tokenized, entities, length):
        # Map every character covered by an entity span to a BIO class index:
        # 'B-<type>' for the first character of a span, 'I-<type>' for the rest.
        char_to_class = {}
        for entity in entities:
            for span in entities[entity]["span"]:
                for i in range(span[0], span[1]):
                    prefix = 'B-' if i == span[0] else 'I-'
                    char_to_class[i] = self.class_to_index[prefix + str(entities[entity]["type"])]
        # Characters outside every entity get class 0 ('O').
        for i in range(length):
            if i not in char_to_class:
                char_to_class[i] = 0
        # Project character-level classes onto tokens: each token takes the
        # class of its first character; special tokens, which map to no
        # character span, get the ignore index -100.
        classes = []
        for i in range(len(text_tokenized[0])):
            span = text_tokenized.token_to_chars(i)
            if span is not None:
                classes.append(char_to_class[span.start])
            else:
                classes.append(-100)
        return torch.LongTensor(classes)
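    # Worked example (hypothetical entity, for illustration only): for the text
    # "EGFR mutation" with entities {"e1": {"span": [[0, 4]], "type": "GENE"}},
    # character 0 maps to the 'B-GENE' index and characters 1-3 to 'I-GENE';
    # all remaining characters map to 0 ('O'). A token covering characters 0-1
    # therefore gets the 'B-GENE' class, while special tokens such as <s>/</s>
    # receive -100 and are ignored by the loss.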
    def make_html(self, word_tokens, predictions):
        # Render token-level predictions as an HTML string, wrapping each
        # contiguous run of a non-'O' label in a highlighted <span>. The
        # original markup in this function was lost in extraction; the
        # <html>/<body>/<span> tags below are a plausible reconstruction.
        toreturn = '''<html><body>
'''
        last_label = None
        for idx, item in enumerate(word_tokens):
            decoded = self.tokenizer.decode(item, skip_special_tokens=True)
            # Special tokens decode to the empty string and are skipped.
            if len(decoded) > 0:
                # WordPiece continuations start with '##'; only insert a space
                # before tokens that begin a new word.
                if idx != 0 and decoded[0] != '#':
                    toreturn += " "
                label = predictions[idx]
                if label == last_label:
                    toreturn += decoded if decoded[0] != "#" else decoded[2:]
                else:
                    # Close the previous entity span, if one is open.
                    if last_label is not None and last_label > 0:
                        toreturn += "</span>"
                    # Open a new span for a non-'O' label.
                    if label > 0:
                        toreturn += '<span class="entity" title="{}">'.format(
                            self.index_to_class.get(label, label))
                        toreturn += decoded if decoded[0] != "#" else decoded[2:]
                    if label == 0:
                        toreturn += decoded if decoded[0] != "#" else decoded[2:]
                last_label = label
        # Close a span left open by the final token. (The original checked
        # `idx == len(word_tokens)` inside the loop, which can never hold.)
        if last_label is not None and last_label > 0:
            toreturn += "</span>"
        toreturn += '''
</body></html>'''
        return toreturn
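
# Hypothetical usage of make_html after inference (the names below are
# illustrative, not part of this module): given `logits` from a
# token-classification model with shape (batch, seq_len, num_classes),
#
#     predictions = logits.argmax(dim=-1)[0].tolist()
#     html = dataset.make_html(sentences[0].tolist(), predictions)
#     with open('ner_vis.html', 'w') as f:
#         f.write(html)
#
# writes a standalone page with predicted entities wrapped in <span> tags.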


def get_collate_fn():
    def collate(batch):
        # Pad variable-length examples into batch tensors. Input ids and
        # attention masks pad with 0 (padded positions are masked out
        # regardless of the pad token id); labels pad with -100 so the loss
        # ignores them.
        sentences = []
        masks = []
        refs = []
        for ex in batch:
            sentences.append(torch.LongTensor(ex[0]['input_ids']))
            masks.append(torch.Tensor(ex[0]['attention_mask']))
            refs.append(ex[1])
        sentences = pad_sequence(sentences, batch_first=True, padding_value=0)
        masks = pad_sequence(masks, batch_first=True, padding_value=0)
        refs = pad_sequence(refs, batch_first=True, padding_value=-100)
        return sentences, masks, refs
    return collate
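

# A minimal end-to-end sketch of how this module fits together. The argument
# values below (checkpoint name, file name, corpus id) are assumptions for
# illustration; in practice they come from the training script's argument
# parser. Because this file uses a relative import, run it as part of the
# package (e.g. `python -m <package>.dataset`) rather than as a bare script.
if __name__ == "__main__":
    from argparse import Namespace
    from torch.utils.data import DataLoader

    args = Namespace(
        data_path="data",                   # assumed directory of JSON splits
        roberta_checkpoint="roberta-base",  # assumed encoder checkpoint
        cache_dir=None,
        corpus="example_corpus",            # assumed key for get_class_to_index
        max_seq_length=512,
    )
    dataset = NERDataset(args, "train.json", split="train")
    loader = DataLoader(dataset, batch_size=8, shuffle=True,
                        collate_fn=get_collate_fn())
    sentences, masks, refs = next(iter(loader))
    # sentences/masks: (batch, max_len); refs holds BIO class indices, with
    # -100 on special tokens and padding.
    print(sentences.shape, masks.shape, refs.shape)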