# Ajay Karthick Senthil Kumar
# update
# d5bc1d8
import html
import os

import streamlit as st
import tensorflow as tf

# Import your utility functions
from utils import (
    predict_multi_line_text,
    tokenizer,
)
from config import index_to_label, acronyms_to_entities, MAX_LENGTH
from metrics import precision, recall, f1_score
# Register the custom metric functions in Keras' global custom-object
# registry under their own names, before the saved model is loaded.
_custom_objects = tf.keras.utils.get_custom_objects()
for _metric_fn in (precision, recall, f1_score):
    _custom_objects[_metric_fn.__name__] = _metric_fn

# Load your trained model
model_dir = './model'  # Adjust the path as needed


# Streamlit re-executes this script on every user interaction; caching the
# loaded model as a resource avoids re-reading the .h5 file on each rerun.
# show_spinner=False keeps the cached call from rendering anything, since
# st.set_page_config is called later in this script and some Streamlit
# versions require it to be the first Streamlit command.
@st.cache_resource(show_spinner=False)
def _load_model(path):
    """Load the Keras model stored at ``path`` (cached across reruns).

    Args:
        path: Filesystem path of the saved model file.

    Returns:
        Whatever ``tf.keras.models.load_model`` returns for that path.
    """
    return tf.keras.models.load_model(path)


model_1 = _load_model(os.path.join(model_dir, 'model_1.h5'))
# Highlight colour (hex) for each entity label, chosen for a dark page
# background. Labels missing from this table are handled by the caller.
LABEL_COLORS = {
    "Activity": "#FF7F50",                # coral
    "Administration": "#6495ED",          # cornflower blue
    "Age": "#FFB6C1",                     # light pink
    "Area": "#7FFF00",                    # chartreuse
    "Biological_attribute": "#FFD700",    # gold
    "Biological_structure": "#00FA9A",    # medium spring green
    "Clinical_event": "#BA55D3",          # medium orchid
    "Color": "#00CED1",                   # dark turquoise
    "Coreference": "#FFA07A",             # light salmon
    "Date": "#ADFF2F",                    # green yellow
    "Detailed_description": "#DA70D6",    # orchid
    "Diagnostic_procedure": "#87CEFA",    # light sky blue
    "Disease_disorder": "#FF4500",        # orange red
    "Distance": "#32CD32",                # lime green
    "Dosage": "#8A2BE2",                  # blue violet
    "Duration": "#F08080",                # light coral
    "Family_history": "#20B2AA",          # light sea green
    "Frequency": "#FF6347",               # tomato
    "Height": "#4682B4",                  # steel blue
    "History": "#EE82EE",                 # violet
    "Lab_value": "#FFDAB9",               # peach puff
    "Mass": "#7B68EE",                    # medium slate blue
    "Medication": "#00FF7F",              # spring green
    "Nonbiological_location": "#FF69B4",  # hot pink
    "Occupation": "#BDB76B",              # dark khaki
    "Other_entity": "#D3D3D3",            # light grey
    "Other_event": "#FF1493",             # deep pink
    "Outcome": "#00BFFF",                 # deep sky blue
    "Personal_background": "#00FFFF",     # aqua
    "Qualitative_concept": "#FFA500",     # orange
    "Quantitative_concept": "#FFA500",    # orange (shared with Qualitative_concept)
    "Severity": "#1E90FF",                # dodger blue
    "Sex": "#FF00FF",                     # magenta
    "Shape": "#40E0D0",                   # turquoise
    "Sign_symptom": "#FF69B4",            # hot pink (shared with Nonbiological_location)
    "Subject": "#F0E68C",                 # khaki
    "Texture": "#98FB98",                 # pale green
    "Therapeutic_procedure": "#8B008B",   # dark magenta
    "Time": "#DC143C",                    # crimson
    "Volume": "#5F9EA0",                  # cadet blue
    "Weight": "#FA8072",                  # salmon
}
# Define the prediction function
def predict_ner(text):
    """Run entity prediction on ``text`` and return highlighted HTML.

    The spans returned by ``predict_multi_line_text`` are unpacked as
    ``(start, end, label)`` and used as character offsets into ``text``.
    Each span is wrapped in a coloured ``<span>`` followed by its label;
    the text between spans is emitted as-is (HTML-escaped).

    Args:
        text: The raw, user-supplied text to analyse.

    Returns:
        An HTML fragment as a string. If anything raises, a red
        ``<p>`` element containing the (escaped) error message is
        returned instead of propagating the exception.
    """
    try:
        # Predict entities
        entities = predict_multi_line_text(
            text,
            model_1,
            index_to_label,
            acronyms_to_entities,
            MAX_LENGTH,
        )
        # Sort entities by their start position
        entities = sorted(entities, key=lambda x: x[0])

        # Collect fragments and join once instead of repeated `+=`.
        parts = []
        last_idx = 0
        for start, end, label in entities:
            # Clamp spans that overlap an already-emitted entity so no part
            # of the input is rendered twice; drop spans left empty.
            start = max(start, last_idx)
            if start >= end:
                continue
            # Append text before the entity. The input is untrusted and the
            # result is rendered with unsafe_allow_html, so escape it.
            if last_idx < start:
                parts.append(html.escape(text[last_idx:start]))
            # Get the color for the label, default to light grey if not specified
            color = LABEL_COLORS.get(label, '#D3D3D3')  # Light grey
            entity_text = html.escape(text[start:end])
            label_text = html.escape(str(label))
            # Wrap the entity with a styled span and show the label next to it
            parts.append(
                f'''<span style="background-color: {color}; font-weight: bold; padding: 2px; border-radius: 4px; margin: 1px;">{entity_text} <span style="font-size: smaller; font-weight: normal;">[{label_text}]</span></span>'''
            )
            last_idx = end
        # Append any remaining text
        if last_idx < len(text):
            parts.append(html.escape(text[last_idx:]))
        return "".join(parts)
    except Exception as e:
        # Top-level UI boundary: surface the failure in the page rather than
        # crashing the app. The message is escaped because it may echo input.
        return f"<p style='color:red;'>Error: {html.escape(str(e))}</p>"
# Page-level configuration: title, icon and wide layout.
st.set_page_config(page_title="Medical NER", page_icon="🩺", layout="wide")

# Stylesheet injected into the page to give it a dark look: app background,
# text area, the Analyze button and WebKit scrollbars. It also defines a
# `.highlighted-entity` class; the entity spans built in this file as shown
# carry inline styles rather than that class.
_DARK_THEME_CSS = """
<style>
/* Main app background */
.stApp {
background-color: #2E2E2E;
color: #FFFFFF;
}
/* Text input area */
.stTextArea textarea {
background-color: #1E1E1E;
color: #FFFFFF;
}
/* Adjust the Analyze button */
div.stButton > button:first-child {
background-color: #1E90FF;
color: #FFFFFF;
}
/* Scrollbar styling */
::-webkit-scrollbar {
width: 10px;
}
::-webkit-scrollbar-track {
background: #1E1E1E;
}
::-webkit-scrollbar-thumb {
background: #888;
}
::-webkit-scrollbar-thumb:hover {
background: #555;
}
/* Style for the highlighted entities */
.highlighted-entity {
padding: 2px;
border-radius: 4px;
margin: 1px;
font-weight: bold;
display: inline-block;
}
</style>
"""

# unsafe_allow_html is required for the <style> element to be emitted as HTML.
st.markdown(_DARK_THEME_CSS, unsafe_allow_html=True)
# Page heading and a short description of what the tool does.
st.title("🩺 Medical Named Entity Recognition")
_intro_markdown = """
Enter medical text below to identify and highlight entities such as diseases, medications, and anatomical terms.
"""
st.markdown(_intro_markdown)

# Free-text input for the passage to analyse.
text_input = st.text_area("Enter medical text here:", height=200)

# Run the analysis only on the rerun triggered by the button click.
analyze_clicked = st.button("Analyze")
if analyze_clicked and not text_input.strip():
    # Nothing but whitespace was entered.
    st.warning("Please enter some text to analyze.")
elif analyze_clicked:
    with st.spinner("Analyzing..."):
        result = predict_ner(text_input)
        # predict_ner returns an HTML fragment, so render it as HTML.
        rendered = f"<div style='font-size: 18px;'>{result}</div>"
        st.markdown(rendered, unsafe_allow_html=True)