# ChemEagle_API / chemiener / interface.py
# (Hugging Face Hub page header, kept for provenance; not part of the program)
# CYF200127's picture
# Upload 162 files
# 1f516b6 verified
import os
import argparse
from typing import List
import torch
import numpy as np
from .model import build_model
from .dataset import NERDataset, get_collate_fn
from huggingface_hub import hf_hub_download
from .utils import get_class_to_index
class ChemNER:
    """Inference wrapper around a token-classification NER checkpoint.

    The constructor loads a checkpoint from ``model_path``, builds the model
    through ``build_model`` and keeps the tokenizer exposed by ``NERDataset``.
    ``predict_strings`` then turns raw strings into lists of
    ``(entity_type, [char_start, char_end])`` tuples by merging per-token
    ``B-``/``I-`` tags.
    """

    def __init__(self, model_path, device=None, cache_dir=None):
        """Load the checkpoint and prepare tokenizer, collate fn and label maps.

        Args:
            model_path: Path (or file-like object) handed to ``torch.load``.
                The loaded object must expose a ``'state_dict'`` entry.
            device: Device the model is moved to. Defaults to CPU when None.
            cache_dir: Stored on the args namespace as ``args.cache_dir``
                (presumably used by ``build_model`` / ``NERDataset`` as the
                download cache -- confirm in those modules).
        """
        self.args = self._get_args(cache_dir)
        # NOTE(security): depending on the installed torch version and its
        # ``weights_only`` default, torch.load may unpickle arbitrary objects.
        # Only load checkpoints that come from a trusted source.
        states = torch.load(model_path, map_location=torch.device('cpu'))
        if device is None:
            device = torch.device('cpu')
        self.device = device
        self.model = self.get_model(self.args, device, states['state_dict'])
        self.collate = get_collate_fn()
        # data_file=None: the dataset object is only used for its tokenizer.
        self.dataset = NERDataset(self.args, data_file=None)
        self.class_to_index = get_class_to_index(self.args.corpus)
        # Inverse mapping, used to turn argmax indices back into tag names.
        self.index_to_class = {
            index: name for name, index in self.class_to_index.items()
        }

    def _get_args(self, cache_dir):
        """Build the argument namespace consumed by the model and dataset.

        No command-line arguments are read: the parser is run on an empty
        list, so the declared defaults are always used.

        Args:
            cache_dir: Value stored as ``args.cache_dir``.

        Returns:
            An ``argparse.Namespace`` with ``roberta_checkpoint``, ``corpus``
            and ``cache_dir`` attributes.
        """
        parser = argparse.ArgumentParser()
        parser.add_argument(
            '--roberta_checkpoint',
            default='dmis-lab/biobert-large-cased-v1.1',
            type=str,
            help='which roberta config to use',
        )
        parser.add_argument(
            '--corpus',
            default="chemdner",
            type=str,
            help="which corpus should the tags be from",
        )
        args = parser.parse_args([])
        args.cache_dir = cache_dir
        return args

    def get_model(self, args, device, model_states):
        """Instantiate the model, load weights and switch to eval mode.

        Args:
            args: Namespace passed to ``build_model``.
            device: Device the model is moved to.
            model_states: Mapping of parameter names to tensors.

        Returns:
            The model on ``device`` in evaluation mode.
        """
        model = build_model(args)

        def remove_prefix(state_dict):
            # Replaces *every* occurrence of 'model.' in a key, not only a
            # leading one. Kept as-is because the shipped checkpoints may rely
            # on it -- TODO confirm against the checkpoint key layout.
            return {k.replace('model.', ''): v for k, v in state_dict.items()}

        # strict=False: keys that do not match are silently ignored.
        model.load_state_dict(remove_prefix(model_states), strict=False)
        model.to(device)
        model.eval()
        return model

    @staticmethod
    def _merge_bio_spans(char_spans, tags):
        """Merge per-token tags into ``(entity_type, [start, end])`` tuples.

        A tag starting with ``'B'`` opens a new entity whose type is the tag
        with its first two characters dropped. Any other tag whose type
        equals that of the most recent entity extends that entity's end
        offset -- even if untyped tokens were seen in between. Everything
        else is ignored.

        Args:
            char_spans: Objects with ``start`` and ``end`` attributes, one per
                kept token.
            tags: Tag strings aligned with ``char_spans``.

        Returns:
            List of ``(entity_type, [char_start, char_end])`` tuples in order
            of appearance.
        """
        entities = []
        for span, tag in zip(char_spans, tags):
            label = tag[2:]
            if tag.startswith('B'):
                entities.append((label, [span.start, span.end]))
            elif entities and label == entities[-1][0]:
                entities[-1][1][1] = span.end
        return entities

    def predict_strings(self, strings: List[str], batch_size=8):
        """Run NER over raw strings and return character-level entity spans.

        Args:
            strings: Input texts. Each is tokenized with truncation at 512
                tokens.
            batch_size: Number of strings per forward pass.

        Returns:
            One list per input string, each containing
            ``(entity_type, [char_start, char_end])`` tuples.
        """
        device = self.device
        tokenizer = self.dataset.tokenizer
        output = []
        # Inference only: skip autograd bookkeeping for the forward passes.
        with torch.no_grad():
            for idx in range(0, len(strings), batch_size):
                batch_strings = strings[idx:idx + batch_size]
                encodings = [
                    tokenizer(s, truncation=True, max_length=512)
                    for s in batch_strings
                ]
                # The collate function expects 3-tuples; the two trailing
                # tensors are placeholders and their collated value is unused.
                batch = [
                    (enc, torch.Tensor([-1]), torch.Tensor([-1]))
                    for enc in encodings
                ]
                sentences, masks, refs = self.collate(batch)
                logits = self.model(
                    input_ids=sentences.to(device),
                    attention_mask=masks.to(device),
                )[0]
                predictions = logits.argmax(dim=2).to('cpu')

                rows = zip(encodings, sentences.tolist(), predictions.tolist())
                for encoding, token_ids, label_ids in rows:
                    spans = []
                    tags = []
                    for pos, (token_id, label_id) in enumerate(
                        zip(token_ids, label_ids)
                    ):
                        # Tokens that decode to nothing once special tokens
                        # are skipped (e.g. padding) carry no text: drop them.
                        if not tokenizer.decode(
                            int(token_id), skip_special_tokens=True
                        ):
                            continue
                        spans.append(encoding.token_to_chars(pos))
                        tags.append(self.index_to_class[int(label_id)])
                    output.append(self._merge_bio_spans(spans, tags))
        return output