import argparse
from typing import List

import torch

from .dataset import NERDataset, get_collate_fn
from .model import build_model
from .utils import get_class_to_index

class ChemNER:
    """Chemical named-entity tagger built on a fine-tuned BioBERT token classifier."""

    def __init__(self, model_path, device=None, cache_dir=None):
        self.args = self._get_args(cache_dir)
        # Load the checkpoint onto CPU first; the model is moved to `device` below.
        states = torch.load(model_path, map_location=torch.device('cpu'))
        if device is None:
            device = torch.device('cpu')
        self.device = device
        self.model = self.get_model(self.args, device, states['state_dict'])
        self.collate = get_collate_fn()
        # The dataset is instantiated without a data file; only its tokenizer is used here.
        self.dataset = NERDataset(self.args, data_file=None)
        self.class_to_index = get_class_to_index(self.args.corpus)
        self.index_to_class = {index: label for label, index in self.class_to_index.items()}

    def _get_args(self, cache_dir):
        # Rebuild the training-time argument namespace with its defaults so the
        # model and dataset constructors see the fields they expect.
        parser = argparse.ArgumentParser()
        parser.add_argument('--roberta_checkpoint', default='dmis-lab/biobert-large-cased-v1.1',
                            type=str, help='which transformer checkpoint to use')
        parser.add_argument('--corpus', default='chemdner', type=str,
                            help='which corpus the tag set comes from')
        args = parser.parse_args([])
        args.cache_dir = cache_dir
        return args

    def get_model(self, args, device, model_states):
        model = build_model(args)

        def remove_prefix(state_dict):
            # Checkpoint keys are saved with a 'model.' prefix by the training
            # wrapper; strip it so the keys match the bare model.
            return {k.replace('model.', ''): v for k, v in state_dict.items()}

        model.load_state_dict(remove_prefix(model_states), strict=False)
        model.to(device)
        model.eval()
        return model
    
    def predict_strings(self, strings: List[str], batch_size=8):
        device = self.device

        def prepare_output(char_span, prediction):
            # Merge per-token BIO tags into (entity_type, [start_char, end_char]) spans:
            # a 'B-' tag opens a new span, and a following tag with the same class
            # (an 'I-' tag) extends the most recent one.
            entities = []
            for i in range(len(char_span)):
                if prediction[i][0] == 'B':
                    entities.append((prediction[i][2:], [char_span[i].start, char_span[i].end]))
                elif len(entities) > 0 and prediction[i][2:] == entities[-1][0]:
                    entities[-1] = (entities[-1][0], [entities[-1][1][0], char_span[i].end])
            return entities

        output = []
        for idx in range(0, len(strings), batch_size):
            batch_strings = strings[idx:idx + batch_size]
            # Pair each encoding with dummy mask/label tensors so the training-time
            # collate function can be reused unchanged.
            batch_strings_tokenized = [
                (self.dataset.tokenizer(s, truncation=True, max_length=512),
                 torch.Tensor([-1]), torch.Tensor([-1]))
                for s in batch_strings
            ]
            sentences, masks, refs = self.collate(batch_strings_tokenized)

            with torch.no_grad():
                predictions = self.model(input_ids=sentences.to(device),
                                         attention_mask=masks.to(device))[0].argmax(dim=2).to('cpu')

            sentences_list = list(sentences)
            predictions_list = list(predictions)

            # Map every real token back to its character span in the source string,
            # skipping special and padding tokens (which decode to the empty string).
            char_spans = []
            for j, sentence in enumerate(sentences_list):
                to_add = [batch_strings_tokenized[j][0].token_to_chars(i)
                          for i, word in enumerate(sentence)
                          if len(self.dataset.tokenizer.decode(int(word.item()), skip_special_tokens=True)) > 0]
                char_spans.append(to_add)

            # Decode predicted class indices with the same token filter so the
            # tag sequence stays aligned with char_spans.
            class_predictions = [
                [self.index_to_class[int(pred.item())]
                 for (pred, word) in zip(sentence_p, sentence_w)
                 if len(self.dataset.tokenizer.decode(int(word.item()), skip_special_tokens=True)) > 0]
                for (sentence_p, sentence_w) in zip(predictions_list, sentences_list)
            ]

            output += [prepare_output(char_span, prediction)
                       for char_span, prediction in zip(char_spans, class_predictions)]

        return output
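

# Minimal usage sketch. The checkpoint path and input text are illustrative
# placeholders (no 'chemner.ckpt' ships with this module), and since this file
# uses relative imports it must be run as a module, e.g. `python -m <package>.<module>`.
# predict_strings returns, per input string, a list of
# (entity_type, [start_char, end_char]) tuples.
if __name__ == '__main__':
    tagger = ChemNER('chemner.ckpt')
    for spans in tagger.predict_strings(['Aspirin inhibits cyclooxygenase-2.']):
        print(spans)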