import re
import os
import pickle

import numpy as np
import nltk
from nltk.corpus import stopwords
from tensorflow.keras.preprocessing.sequence import pad_sequences

nltk.download('punkt')
nltk.download('punkt_tab')
nltk.download('stopwords')

STOP_WORDS = set(stopwords.words('english'))

# Load the tokenizer from file
with open('./data/tokenizer.pickle', 'rb') as handle:
    tokenizer = pickle.load(handle)
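
# Quick sanity check (sketch, commented out): the pickled object is assumed to be a
# Keras Tokenizer fitted on the training corpus, so texts_to_sequences maps words to
# integer ids from that vocabulary. The exact ids depend on the fitted vocabulary.
# print(tokenizer.texts_to_sequences(['patient denies chest pain']))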


def clean_word(word):
    """
    Cleans a word by removing non-alphanumeric characters and extra whitespace,
    converting it to lowercase, and checking whether it is a stopword.

    Args:
    - word (str): the word to clean

    Returns:
    - str: the cleaned word, or an empty string if it is a stopword
    """
    # Remove non-alphanumeric characters and collapse extra whitespace
    word = re.sub(r'[^\w\s]', '', word)
    word = re.sub(r'\s+', ' ', word)
    # Convert to lowercase
    word = word.lower()
    if word not in STOP_WORDS:
        return word
    return ''
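
# Examples (illustrative): punctuation is stripped and stopwords map to '':
#   clean_word('Hello,')   -> 'hello'
#   clean_word('COVID-19') -> 'covid19'
#   clean_word('The')      -> ''   (NLTK English stopword)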


def tokenize_text(text):
    """
    Tokenizes a text into lists of cleaned words, sentence by sentence.

    Args:
    - text (str): the text to tokenize

    Returns:
    - tokens (list of list of str): the cleaned words of each sentence
    - start_end_ranges (list of list of tuples): the (start, end) character positions for each token
    """
    # Match runs of characters that are not whitespace or hyphen-like separators
    regex_match = r'[^\s\u200a\-\u2010-\u2015\u2212\uff0d]+'
    tokens = []
    start_end_ranges = []
    # Split the text into sentences, then tokenize each sentence
    sentences = nltk.sent_tokenize(text)
    start = 0
    for sentence in sentences:
        sentence_tokens = re.findall(regex_match, sentence)
        curr_sent_tokens = []
        curr_sent_ranges = []
        for word in sentence_tokens:
            word = clean_word(word)
            if word.strip():
                # Locate the cleaned word in the (lowercased) original text to get its offsets.
                # If cleaning changed the word so that it no longer appears verbatim, skip it
                # rather than recording a bogus (-1, ...) range.
                found = text.lower().find(word, start)
                if found == -1:
                    continue
                start = found
                end = start + len(word)
                curr_sent_ranges.append((start, end))
                curr_sent_tokens.append(word)
                start = end
        if len(curr_sent_tokens) > 0:
            tokens.append(curr_sent_tokens)
            start_end_ranges.append(curr_sent_ranges)
    return tokens, start_end_ranges
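
# Example (illustrative): sentences are tokenized separately and offsets index into
# the original string.
#   tokens, ranges = tokenize_text('Patient denies pain. No fever reported.')
#   tokens -> [['patient', 'denies', 'pain'], ['fever', 'reported']]
#             ('No' is dropped as a stopword)
#   ranges -> [[(0, 7), (8, 14), (15, 19)], [(24, 29), (30, 38)]]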


def predict_multi_line_text(text, model, index_to_label, acronyms_to_entities, MAX_LENGTH):
    """
    Predicts named entities for multi-line input text.

    Args:
    - text (str): The input text
    - model: The trained NER model
    - index_to_label: Dictionary mapping index to label
    - acronyms_to_entities: Dictionary mapping label acronyms to entity names
    - MAX_LENGTH: Maximum input length for the model

    Returns:
    - entities: A list of named entities in the format (start, end, entity_type)
    """
    # Tokenize the text and convert each sentence to an integer sequence
    sent_tokens, sent_start_end = tokenize_text(text)
    sequences = []
    for sentence_tokens in sent_tokens:
        sequences.extend(tokenizer.texts_to_sequences([' '.join(sentence_tokens)]))
    padded_sequences = pad_sequences(sequences, maxlen=MAX_LENGTH, padding='post')

    # Predict label probabilities for every token position in every sentence
    prediction = model.predict(padded_sequences)

    # Decode the prediction back to label strings
    predicted_labels = np.argmax(prediction, axis=-1)
    predicted_labels = [
        [index_to_label[i] for i in sent_predicted_labels]
        for sent_predicted_labels in predicted_labels
    ]

    # Collect (start, end, entity_type) spans for every non-'O', non-padding label.
    # Labels carry a two-character prefix (e.g. 'B-'), so label[2:] is the entity acronym.
    entities = []
    for sent_pred_labels, start_end_ranges in zip(predicted_labels, sent_start_end):
        for (start, end), label in zip(start_end_ranges, sent_pred_labels):
            if label not in ['O', '<PAD>']:
                entity_type = acronyms_to_entities[label[2:]]
                entities.append((start, end, entity_type))
    return entities
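

# Illustrative usage sketch. The model path, label map, entity map, and MAX_LENGTH
# below are placeholders; they are not defined in this module and must match the
# artifacts produced by your own training run.
if __name__ == '__main__':
    from tensorflow.keras.models import load_model

    model = load_model('./data/ner_model.h5')            # placeholder path
    index_to_label = {0: '<PAD>', 1: 'O', 2: 'B-DIS'}    # placeholder label map
    acronyms_to_entities = {'DIS': 'Disease'}            # placeholder entity names
    MAX_LENGTH = 128                                     # must match the training setup

    sample_text = 'Patient reports severe headache and nausea.'
    for start, end, entity_type in predict_multi_line_text(
        sample_text, model, index_to_label, acronyms_to_entities, MAX_LENGTH
    ):
        print(f'{sample_text[start:end]!r} [{start}:{end}] -> {entity_type}')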