Spaces:
Sleeping
Sleeping
File size: 2,847 Bytes
280d87f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 |
import torch
from transformers import BertForTokenClassification
from .config_train import device, model_load_path, tokenizer
from .DataProcessing import read_input
from .load_data import sorted_tags
class Key_Ner_Predictor:
    """Run a BERT token-classification model and report the distinct tags it predicts."""

    def __init__(self, model_path, tokenizer, device, tag_map):
        """
        Initialize the Key_Ner_Predictor with the model, tokenizer, and device.

        Args:
            model_path (str): Path to the pre-trained model.
            tokenizer (BertTokenizer): Tokenizer to process input sentences.
            device (torch.device): Device to run the model on.
            tag_map (Dict[int, str]): Mapping of indices to tags.
        """
        self.model = BertForTokenClassification.from_pretrained(model_path).to(device)
        self.tokenizer = tokenizer
        self.device = device
        self.tag_map = tag_map

    @staticmethod
    def _unique_tags(tag_ids, tag_map, pad_tag='<pad>'):
        """
        Convert predicted tag indices into a de-duplicated list of tag names.

        Args:
            tag_ids (Iterable[int]): Predicted tag index for each token.
            tag_map (Dict[int, str]): Mapping of indices to tags.
            pad_tag (str): Tag name to leave out of the result.

        Returns:
            List[str]: Distinct tags in order of first appearance, without
            ``pad_tag`` and with spaces replaced by underscores.
        """
        # dict.fromkeys de-duplicates while keeping first-seen order, so the
        # result is deterministic (a set would yield an arbitrary order).
        distinct = dict.fromkeys(tag_map[tag_id] for tag_id in tag_ids)
        # The pad tag is not always predicted; dropping it must not raise.
        distinct.pop(pad_tag, None)
        return [tag.replace(" ", "_") for tag in distinct]

    def predict(self, sentence):
        """
        Predict the set of tags present in the given sentence.

        Args:
            sentence (str): Input sentence to predict.

        Returns:
            Tuple[str, List[str]]: The sentence decoded back from its token ids
            (special tokens skipped) and the distinct predicted tags, excluding
            '<pad>', with spaces replaced by underscores.
        """
        # Pre-process the raw text with the project's read_input helper
        sentence = read_input(sentence)
        # Tokenize the sentence
        input_ids = self.tokenizer.encode(sentence, return_tensors="pt").to(self.device)
        # Create attention masks (1.0 for every non-padding token)
        attention_masks = (input_ids != self.tokenizer.pad_token_id).float().to(self.device)
        # Set model to evaluation mode
        self.model.eval()
        with torch.no_grad():
            # Forward pass
            outputs = self.model(input_ids, token_type_ids=None, attention_mask=attention_masks)
            logits = outputs.logits
        # Highest-scoring tag index for each token of the single input sentence
        tag_ids = torch.argmax(logits, dim=2)[0].cpu().tolist()
        # Map indices to distinct tag names
        predicted_tags = self._unique_tags(tag_ids, self.tag_map)
        return self.tokenizer.decode(input_ids[0], skip_special_tokens=True), predicted_tags
# Module-level predictor instance; constructing it loads the model when this module is imported.
predictor = Key_Ner_Predictor(
    model_load_path,
    tokenizer,
    device,
    tag_map={position: tag for position, tag in enumerate(sorted_tags)},
)
# Example usage (left commented out so that importing this module does not run inference).
# The first returned value is the sentence decoded from its token ids, not the raw input.
# sentence = "Tôi muốn đi cắm trại ngắm hoàng hôn trên biển cùng gia đình"
# original_sentence, predicted_tags = predictor.predict(sentence)
# print("Sentence:", original_sentence)
# print("Predicted Tags:", predicted_tags)
|