import re
import os
import pickle
import numpy as np
from tensorflow.keras.preprocessing.sequence import pad_sequences
import nltk

# Fetch the NLTK resources needed for sentence tokenization and stopword filtering
nltk.download('punkt')
nltk.download('punkt_tab')
nltk.download('stopwords')
from nltk.corpus import stopwords

STOP_WORDS = set(stopwords.words('english'))  # set for fast membership checks

# Load the tokenizer from file
with open('./data/tokenizer.pickle', 'rb') as handle:
    tokenizer = pickle.load(handle)

def clean_word(word):
    """
    Cleans a word by removing non-alphanumeric characters and extra whitespace,
    converting it to lowercase, and dropping it if it is a stopword.

    Args:
    - word (str): the word to clean

    Returns:
    - str: the cleaned word, or an empty string if it is a stopword
    """
    # remove non-alphanumeric characters and extra whitespaces
    word = re.sub(r'[^\w\s]', '', word)
    word = re.sub(r'\s+', ' ', word)
    
    # convert to lowercase
    word = word.lower()
    
    if word not in STOP_WORDS:
        return word
    
    return ''
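
# Illustrative behaviour of clean_word ("Fever," and "The" are made-up example inputs):
#   clean_word("Fever,")  ->  "fever"   (punctuation stripped, lowercased)
#   clean_word("The")     ->  ""        (stopword, dropped)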

def tokenize_text(text):
    """
    Tokenizes a text into a list of cleaned words.

    Args:
    - text (str): the text to tokenize

    Returns:
    - tokens (list of str): the list of cleaned words
    - start_end_ranges (list of tuples): the start and end character positions for each token
    """
    regex_match = r'[^\s\u200a\-\u2010-\u2015\u2212\uff0d]+'  # runs of characters that are not whitespace or dash-like
    tokens = []
    start_end_ranges = []
    # Split the text into sentences, then tokenize each sentence
    sentences = nltk.sent_tokenize(text)
    
    text_lower = text.lower()
    start = 0
    for sentence in sentences:
        
        sentence_tokens = re.findall(regex_match, sentence)
        curr_sent_tokens = []
        curr_sent_ranges = []
        
        for word in sentence_tokens:
            word = clean_word(word)
            if word.strip():
                found = text_lower.find(word, start)
                if found == -1:
                    # Cleaning can strip characters inside a token (e.g. "don't" -> "dont"),
                    # so the cleaned token may not appear verbatim in the text; skip it
                    # rather than corrupting the offsets of the tokens that follow.
                    continue
                start = found
                end = start + len(word)
                curr_sent_ranges.append((start, end))
                curr_sent_tokens.append(word)
                start = end
        if curr_sent_tokens:
            tokens.append(curr_sent_tokens)
            start_end_ranges.append(curr_sent_ranges)
            
    return tokens, start_end_ranges
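
# Illustrative output of tokenize_text (example input; offsets are computed against
# the original string, and "It" is dropped as a stopword):
#   tokenize_text("Aspirin reduces fever. It works fast.")
#   -> tokens:            [['aspirin', 'reduces', 'fever'], ['works', 'fast']]
#      start_end_ranges:  [[(0, 7), (8, 15), (16, 21)], [(26, 31), (32, 36)]]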

def predict_multi_line_text(text, model, index_to_label, acronyms_to_entities, MAX_LENGTH):
    """
    Predicts named entities for multi-line input text.
    
    Args:
    - text (str): The input text
    - model: The trained NER model
    - index_to_label: Dictionary mapping index to label
    - acronyms_to_entities: Dictionary mapping acronyms to entity names
    - MAX_LENGTH: Maximum input length for the model
    
    Returns:
    - entities: a list of named entities as (start, end, entity_type) character-offset tuples
    """
    
    sequences = []
    sent_tokens, sent_start_end = tokenize_text(text)
    
    for sentence_tokens in sent_tokens:
        # Map each sentence's tokens to vocabulary indices using the fitted tokenizer
        sequences.extend(tokenizer.texts_to_sequences([' '.join(sentence_tokens)]))
    
    padded_sequence = pad_sequences(sequences, maxlen=MAX_LENGTH, padding='post')
    
    # Make the prediction (pad_sequences already returns a NumPy array)
    prediction = model.predict(padded_sequence)
    
    # Decode the prediction
    predicted_labels = np.argmax(prediction, axis=-1)
    
    predicted_labels = [
        [index_to_label[i] for i in sent_predicted_labels]
        for sent_predicted_labels in predicted_labels
    ]
    
    entities = []
    for tokens, sent_pred_labels, start_end_ranges in zip(sent_tokens, predicted_labels, sent_start_end):
        # zip stops at the real tokens, ignoring the padded tail of the label sequence
        for token, label, (start, end) in zip(tokens, sent_pred_labels, start_end_ranges):
            if label not in ('O', '<PAD>'):
                # Strip the BIO prefix ("B-"/"I-") and map the tag acronym to its entity name
                entity_type = acronyms_to_entities[label[2:]]
                entities.append((start, end, entity_type))
    
    return entities
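

# Minimal usage sketch. The model path, label map, acronym map and MAX_LENGTH below
# are placeholders/assumptions -- none of them are defined in this module -- so
# substitute the artifacts produced by your own training run.
if __name__ == '__main__':
    from tensorflow.keras.models import load_model

    ner_model = load_model('./data/ner_model.h5')                    # assumed path
    index_to_label = {0: '<PAD>', 1: 'O', 2: 'B-DIS', 3: 'I-DIS'}    # example mapping
    acronyms_to_entities = {'DIS': 'Disease'}                        # example mapping
    MAX_LENGTH = 128                                                 # must match training

    sample_text = "Aspirin reduces fever. It works fast."
    for start, end, entity_type in predict_multi_line_text(
            sample_text, ner_model, index_to_label, acronyms_to_entities, MAX_LENGTH):
        print(f"{sample_text[start:end]!r} [{start}:{end}] -> {entity_type}")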