import os
import json

import torch
from torch.utils.data import DataLoader, Dataset
from torch.nn.utils.rnn import pad_sequence
from transformers import AutoTokenizer
from .utils import get_class_to_index


class NERDataset(Dataset):
    """Token-classification dataset: loads a JSON split and aligns
    character-level entity spans to tokenizer offsets."""

    def __init__(self, args, data_file, split='train'):
        super().__init__()
        self.args = args
        if data_file:
            data_path = os.path.join(args.data_path, data_file)
            with open(data_path) as f:
                self.data = json.load(f)
            self.name = os.path.basename(data_file).split('.')[0]
        self.split = split
        self.is_train = (split == 'train')
        # A fast tokenizer is required: align_labels relies on token_to_chars offsets.
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.args.roberta_checkpoint, cache_dir=self.args.cache_dir)
        self.class_to_index = get_class_to_index(self.args.corpus)
        self.index_to_class = {v: k for k, v in self.class_to_index.items()}

    def __len__(self):
        return len(self.data)
    def __getitem__(self, idx):
        item = self.data[str(idx)]
        text_tokenized = self.tokenizer(item['text'], truncation=True, max_length=self.args.max_seq_length)
        if len(text_tokenized['input_ids']) > 512:  # sanity check on truncation
            print(len(text_tokenized['input_ids']))
        # The untruncated encoding lets evaluation score labels beyond max_seq_length.
        text_tokenized_untruncated = self.tokenizer(item['text'])
        return (text_tokenized,
                self.align_labels(text_tokenized, item['entities'], len(item['text'])),
                self.align_labels(text_tokenized_untruncated, item['entities'], len(item['text'])))
    def align_labels(self, text_tokenized, entities, length):
        """Map character-level entity spans to one BIO class index per token."""
        char_to_class = {}
        for entity in entities:
            for span in entities[entity]["span"]:
                for i in range(span[0], span[1]):
                    prefix = 'B-' if i == span[0] else 'I-'
                    char_to_class[i] = self.class_to_index[prefix + str(entities[entity]["type"])]
        for i in range(length):
            if i not in char_to_class:
                char_to_class[i] = 0  # 'O' (outside any entity)
        classes = []
        for i in range(len(text_tokenized[0])):
            span = text_tokenized.token_to_chars(i)
            if span is not None:
                # Label each token by the class of its first character.
                classes.append(char_to_class[span.start])
            else:
                # Special tokens get -100 so the loss ignores them.
                classes.append(-100)
        return torch.LongTensor(classes)
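
    # Worked example of the alignment above: for text "water bath" with one entity
    # {"span": [[0, 5]], "type": "SOLVENT"}, character 0 maps to B-SOLVENT and
    # characters 1-4 to I-SOLVENT; all remaining characters fall back to 0 ('O').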

    def make_html(self, word_tokens, predictions):
        """Render per-token predictions as color-coded, underlined HTML."""
        toreturn = '''<!DOCTYPE html>
<html>
<head>
<title>Named Entity Recognition Visualization</title>
<style>
.EXAMPLE_LABEL {
    color: red;
    text-decoration: underline red;
}
.REACTION_PRODUCT {
    color: orange;
    text-decoration: underline orange;
}
.STARTING_MATERIAL {
    color: gold;
    text-decoration: underline gold;
}
.REAGENT_CATALYST {
    color: green;
    text-decoration: underline green;
}
.SOLVENT {
    color: cyan;
    text-decoration: underline cyan;
}
.OTHER_COMPOUND {
    color: blue;
    text-decoration: underline blue;
}
.TIME {
    color: purple;
    text-decoration: underline purple;
}
.TEMPERATURE {
    color: magenta;
    text-decoration: underline magenta;
}
.YIELD_OTHER {
    color: palegreen;
    text-decoration: underline palegreen;
}
.YIELD_PERCENT {
    color: pink;
    text-decoration: underline pink;
}
</style>
</head>
<body>
<p>'''
        last_label = None
        for idx, item in enumerate(word_tokens):
            decoded = self.tokenizer.decode(item, skip_special_tokens=True)
            if len(decoded) > 0:
                # WordPiece continuations start with '##' and attach to the previous token.
                if idx != 0 and decoded[0] != '#':
                    toreturn += " "
                label = predictions[idx]
                if label != last_label:
                    if last_label is not None and last_label > 0:
                        toreturn += "</u>"  # close the previous entity's underline
                    if label > 0:
                        # Strip the B-/I- prefix so the class name matches the CSS above.
                        toreturn += '<u class="' + self.index_to_class[label].split('-', 1)[-1] + '">'
                toreturn += decoded if decoded[0] != '#' else decoded[2:]
                last_label = label
        # Close an underline still open at the end of the sequence; the original
        # idx == len(word_tokens) check could never be true inside the loop.
        if last_label is not None and last_label > 0:
            toreturn += "</u>"
        toreturn += '''</p>
</body>
</html>'''
        return toreturn
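
    # A minimal rendering sketch (hypothetical names), assuming `logits` comes from
    # a token-classification model run on one example of this dataset:
    #
    #     encoding, labels, _ = dataset[0]
    #     preds = logits.argmax(dim=-1).tolist()
    #     open('ner_vis.html', 'w').write(dataset.make_html(encoding['input_ids'], preds))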


def get_collate_fn():
    def collate(batch):
        """Pad a batch of (encoding, labels, untruncated_labels) examples."""
        sentences, masks, refs = [], [], []
        for ex in batch:
            sentences.append(torch.LongTensor(ex[0]['input_ids']))
            masks.append(torch.Tensor(ex[0]['attention_mask']))
            refs.append(ex[1])  # truncated labels, aligned with input_ids
        sentences = pad_sequence(sentences, batch_first=True, padding_value=0)
        masks = pad_sequence(masks, batch_first=True, padding_value=0)
        refs = pad_sequence(refs, batch_first=True, padding_value=-100)  # -100 is ignored by the loss
        return sentences, masks, refs
    return collate
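

# Minimal wiring sketch (kept as a comment because this module uses a relative
# import and cannot run as a script). The Namespace fields mirror how `args` is
# consumed above; the concrete values are assumptions for illustration:
#
#     from argparse import Namespace
#     args = Namespace(data_path='data', roberta_checkpoint='roberta-base',
#                      cache_dir=None, max_seq_length=512, corpus='chemu')
#     dataset = NERDataset(args, 'train.json', split='train')
#     loader = DataLoader(dataset, batch_size=8, collate_fn=get_collate_fn())
#     sentences, masks, refs = next(iter(loader))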