# NER-medical-text/utils.py
import re
import pickle

import numpy as np
import nltk
from nltk.corpus import stopwords
from tensorflow.keras.preprocessing.sequence import pad_sequences

# Download the NLTK resources needed for sentence tokenization and stopwords.
nltk.download('punkt')
nltk.download('punkt_tab')
nltk.download('stopwords')

STOP_WORDS = set(stopwords.words('english'))
# Load the fitted tokenizer from file.
with open('./data/tokenizer.pickle', 'rb') as handle:
    tokenizer = pickle.load(handle)
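# Note: this is assumed to be a tensorflow.keras.preprocessing.text.Tokenizer
# saved during training; texts_to_sequences() below relies on its word index.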


def clean_word(word):
    """
    Cleans a word by removing non-alphanumeric characters and extra
    whitespace, converting it to lowercase, and filtering out stopwords.

    Args:
        word (str): the word to clean

    Returns:
        str: the cleaned word, or an empty string if it is a stopword
    """
    # Remove non-alphanumeric characters and collapse extra whitespace.
    word = re.sub(r'[^\w\s]', '', word)
    word = re.sub(r'\s+', ' ', word)
    # Convert to lowercase before the stopword check.
    word = word.lower()
    if word not in STOP_WORDS:
        return word
    return ''
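# Illustrative examples (assuming the NLTK English stopword list):
#   clean_word('Aspirin,')  -> 'aspirin'
#   clean_word('The')       -> ''       ('the' is a stopword)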


def tokenize_text(text):
    """
    Tokenizes a text into sentence-level lists of cleaned words.

    Args:
        text (str): the text to tokenize

    Returns:
        tokens (list of list of str): one list of cleaned words per sentence
        start_end_ranges (list of list of tuple): the (start, end) character
            positions of each token in the original text
    """
    # Match runs of characters that are not whitespace or hyphen/dash variants.
    regex_match = r'[^\s\u200a\-\u2010-\u2015\u2212\uff0d]+'
    tokens = []
    start_end_ranges = []
    # Split the text into sentences, then tokenize each sentence.
    sentences = nltk.sent_tokenize(text)
    start = 0
    for sentence in sentences:
        sentence_tokens = re.findall(regex_match, sentence)
        curr_sent_tokens = []
        curr_sent_ranges = []
        for word in sentence_tokens:
            word = clean_word(word)
            if word.strip():
                # Locate the cleaned token in the lowercased original text.
                # find() returns -1 if cleaning changed the token so much that
                # it no longer appears verbatim; skip the token in that case.
                idx = text.lower().find(word, start)
                if idx == -1:
                    continue
                start = idx
                end = start + len(word)
                curr_sent_ranges.append((start, end))
                curr_sent_tokens.append(word)
                start = end
        if curr_sent_tokens:
            tokens.append(curr_sent_tokens)
            start_end_ranges.append(curr_sent_ranges)
    return tokens, start_end_ranges
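# Illustrative example (token/range values depend on the stopword list;
# 'no' is a stopword, so it is dropped from the second sentence):
#   tokenize_text('Patient denies chest pain. No fever.')
#   -> tokens           = [['patient', 'denies', 'chest', 'pain'], ['fever']]
#      start_end_ranges = [[(0, 7), (8, 14), (15, 20), (21, 25)], [(30, 35)]]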


def predict_multi_line_text(text, model, index_to_label, acronyms_to_entities, MAX_LENGTH):
    """
    Predicts named entities for multi-line input text.

    Args:
        text (str): the input text
        model: the trained NER model
        index_to_label (dict): mapping from class index to BIO label
        acronyms_to_entities (dict): mapping from label acronyms to entity names
        MAX_LENGTH (int): maximum input length the model accepts

    Returns:
        entities (list of tuple): named entities as (start, end, entity_type)
    """
    sent_tokens, sent_start_end = tokenize_text(text)

    # Convert each tokenized sentence into an integer sequence, then pad.
    sequences = []
    for sentence_tokens in sent_tokens:
        sequences.extend(tokenizer.texts_to_sequences([' '.join(sentence_tokens)]))
    padded_sequences = pad_sequences(sequences, maxlen=MAX_LENGTH, padding='post')

    # Run the model and take the most likely label index for each token.
    prediction = model.predict(padded_sequences)
    predicted_labels = np.argmax(prediction, axis=-1)
    predicted_labels = [
        [index_to_label[i] for i in sent_predicted_labels]
        for sent_predicted_labels in predicted_labels
    ]

    # Map BIO labels back to character offsets in the original text.
    # zip() truncates the padded label sequence to the real token count.
    entities = []
    for tokens, sent_pred_labels, start_end_ranges in zip(sent_tokens, predicted_labels, sent_start_end):
        for _token, label, (start, end) in zip(tokens, sent_pred_labels, start_end_ranges):
            if label not in ('O', '<PAD>'):
                # Labels look like 'B-XXX'/'I-XXX'; strip the BIO prefix.
                entity_type = acronyms_to_entities[label[2:]]
                entities.append((start, end, entity_type))
    return entities
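# Usage sketch (illustrative; the model path, label maps, and MAX_LENGTH below
# are assumptions for demonstration, not part of this module):
#
#     from tensorflow.keras.models import load_model
#
#     model = load_model('./models/ner_model.h5')  # hypothetical path
#     index_to_label = {0: '<PAD>', 1: 'O', 2: 'B-DIS', 3: 'I-DIS'}  # example
#     acronyms_to_entities = {'DIS': 'Disease'}  # example
#     entities = predict_multi_line_text(
#         'Patient was diagnosed with diabetes mellitus.',
#         model, index_to_label, acronyms_to_entities, MAX_LENGTH=128,
#     )
#     # -> e.g. [(27, 35, 'Disease'), (36, 44, 'Disease')]
#     #    if the model tags both tokens of 'diabetes mellitus'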