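"""Inference utilities for the keyword NER model.

Loads a fine-tuned BertForTokenClassification checkpoint and exposes a
Key_Ner_Predictor that returns the decoded sentence together with the
unique tags predicted for it.
"""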
import torch
from transformers import BertForTokenClassification

from .config_train import device, model_load_path, tokenizer
from .DataProcessing import read_input
from .load_data import sorted_tags


class Key_Ner_Predictor:
    def __init__(self, model_path, tokenizer, device, tag_map):
        """
        Initialize the Key_Ner_Predictor with the model, tokenizer, and device.

        Args:
            model_path (str): Path to the pre-trained model.
            tokenizer (BertTokenizer): Tokenizer to process input sentences.
            device (torch.device): Device to run the model on.
            tag_map (Dict[int, str]): Mapping of indices to tags.
        """
        self.model = BertForTokenClassification.from_pretrained(model_path).to(device)
        self.tokenizer = tokenizer
        self.device = device
        self.tag_map = tag_map

    def predict(self, sentence):
        """
        Predict the tags present in the given sentence.

        Args:
            sentence (str): Input sentence to predict.

        Returns:
            Tuple[str, List[str]]: The decoded sentence and the unique predicted tags
            (padding tag excluded, spaces replaced with underscores).
        """
        # Process the sentence
        sentence = read_input(sentence)

        # Tokenize the sentence
        input_ids = self.tokenizer.encode(sentence, return_tensors="pt").to(self.device)

        # Create attention masks
        attention_masks = (input_ids != self.tokenizer.pad_token_id).float().to(self.device)

        # Set model to evaluation mode
        self.model.eval()

        with torch.no_grad():
            # Forward pass
            outputs = self.model(input_ids, token_type_ids=None, attention_mask=attention_masks)
            logits = outputs.logits

        # Get the predicted tag index for each token in the sentence
        predicted_indices = torch.argmax(logits, dim=2).cpu().numpy()[0]

        # Map indices to tags and keep only the unique tags
        predicted_tags = {self.tag_map[idx] for idx in predicted_indices}
        # Drop the padding tag if present (discard avoids a KeyError when it is absent)
        predicted_tags.discard('<pad>')
        # Normalize multi-word tags to underscore-separated form
        predicted_tags = [tag.replace(" ", "_") for tag in predicted_tags]

        return self.tokenizer.decode(input_ids[0], skip_special_tokens=True), predicted_tags


# Initialize the Key_Ner_Predictor
predictor = Key_Ner_Predictor(
    model_path=model_load_path,
    tokenizer=tokenizer,
    device=device,
    tag_map=dict(enumerate(sorted_tags))
)

# # Define the sentence to predict
# sentence = "Tôi muốn đi cắm trại ngắm hoàng hôn trên biển cùng gia đình"

# # Get the prediction
# original_sentence, predicted_tags = predictor.predict(sentence)

# # Print the sentence and its predicted tags
# print("Sentence:", original_sentence)
# print("Predicted Tags:", predicted_tags)