Spaces:
Sleeping
Sleeping
import streamlit as st | |
import tensorflow as tf | |
import os | |
# Import your utility functions | |
from utils import ( | |
predict_multi_line_text, | |
tokenizer, | |
) | |
from config import index_to_label, acronyms_to_entities, MAX_LENGTH | |
from metrics import precision, recall, f1_score | |
# Make the project's custom metric functions resolvable by name in Keras'
# global custom-object registry before the saved model is loaded below
# (presumably the .h5 file was compiled with these metrics).
tf.keras.utils.get_custom_objects().update(
    {metric_fn.__name__: metric_fn for metric_fn in (precision, recall, f1_score)}
)
# Load the trained model. Streamlit re-executes this whole script on every
# widget interaction, so the load is wrapped in st.cache_resource to read the
# .h5 file once per server process instead of on every click.
model_dir = './model'  # Adjust the path as needed


@st.cache_resource(show_spinner=False)
def _load_ner_model(model_path):
    """Load the Keras model saved at ``model_path`` and return it.

    The returned object is cached by Streamlit and shared across reruns and
    browser sessions. ``show_spinner=False`` keeps the cache from drawing a
    spinner, so no UI element is emitted before ``st.set_page_config`` runs
    further down (older Streamlit releases require that to be the first
    Streamlit command).

    NOTE(review): the shared model instance is assumed safe to call from
    concurrent sessions; confirm if the app serves many simultaneous users.
    """
    return tf.keras.models.load_model(model_path)


model_1 = _load_ner_model(os.path.join(model_dir, 'model_1.h5'))
# Background colours for each entity type, chosen to stand out on the app's
# dark background. Keys are presumably the entity labels produced by the
# model/config; labels missing here fall back to light grey when rendered.
# Every type has its own colour except the two *_concept types, which
# deliberately share orange.
LABEL_COLORS = {
    'Activity': '#FF7F50',  # Coral
    'Administration': '#6495ED',  # Cornflower Blue
    'Age': '#FFB6C1',  # Light Pink
    'Area': '#7FFF00',  # Chartreuse
    'Biological_attribute': '#FFD700',  # Gold
    'Biological_structure': '#00FA9A',  # Medium Spring Green
    'Clinical_event': '#BA55D3',  # Medium Orchid
    'Color': '#00CED1',  # Dark Turquoise
    'Coreference': '#FFA07A',  # Light Salmon
    'Date': '#ADFF2F',  # Green Yellow
    'Detailed_description': '#DA70D6',  # Orchid
    'Diagnostic_procedure': '#87CEFA',  # Light Sky Blue
    'Disease_disorder': '#FF4500',  # Orange Red
    'Distance': '#32CD32',  # Lime Green
    'Dosage': '#8A2BE2',  # Blue Violet
    'Duration': '#F08080',  # Light Coral
    'Family_history': '#20B2AA',  # Light Sea Green
    'Frequency': '#FF6347',  # Tomato
    'Height': '#4682B4',  # Steel Blue
    'History': '#EE82EE',  # Violet
    'Lab_value': '#FFDAB9',  # Peach Puff
    'Mass': '#7B68EE',  # Medium Slate Blue
    'Medication': '#00FF7F',  # Spring Green
    # Was Hot Pink, identical to Sign_symptom; given its own colour so the
    # two entity types are visually distinguishable.
    'Nonbiological_location': '#CD853F',  # Peru
    'Occupation': '#BDB76B',  # Dark Khaki
    'Other_entity': '#D3D3D3',  # Light Grey
    'Other_event': '#FF1493',  # Deep Pink
    'Outcome': '#00BFFF',  # Deep Sky Blue
    'Personal_background': '#00FFFF',  # Aqua
    'Qualitative_concept': '#FFA500',  # Orange
    'Quantitative_concept': '#FFA500',  # Orange (same as above, intentional)
    'Severity': '#1E90FF',  # Dodger Blue
    'Sex': '#FF00FF',  # Magenta
    'Shape': '#40E0D0',  # Turquoise
    'Sign_symptom': '#FF69B4',  # Hot Pink
    'Subject': '#F0E68C',  # Khaki
    'Texture': '#98FB98',  # Pale Green
    'Therapeutic_procedure': '#8B008B',  # Dark Magenta
    'Time': '#DC143C',  # Crimson
    'Volume': '#5F9EA0',  # Cadet Blue
    'Weight': '#FA8072',  # Salmon
}
# Prediction + HTML rendering helpers.
def _text_color_for(background_hex):
    """Return '#000000' or '#FFFFFF', whichever contrasts more with the background.

    ``background_hex`` must be a '#RRGGBB' string. Uses the WCAG 2.x relative
    luminance formula; 0.179 is the luminance at which black and white text
    give equal contrast ratios. Needed because the badge palette mixes very
    light (e.g. Peach Puff) and very dark (e.g. Dark Magenta) backgrounds, so
    no single inherited text colour is readable on all of them.
    """
    def _linear(channel_hex):
        # sRGB channel (two hex digits) -> linear-light value in [0, 1].
        c = int(channel_hex, 16) / 255
        return c / 12.92 if c <= 0.03928 else ((c + 0.055) / 1.055) ** 2.4

    r, g, b = (_linear(background_hex[i:i + 2]) for i in (1, 3, 5))
    luminance = 0.2126 * r + 0.7152 * g + 0.0722 * b
    return '#000000' if luminance > 0.179 else '#FFFFFF'


def _render_entities(text, entities):
    """Return ``text`` as HTML with every entity span wrapped in a coloured badge.

    Args:
        text: The original input string.
        entities: Iterable of ``(start, end, label)`` triples; ``text[start:end]``
            is the entity text (presumably character offsets, end-exclusive).

    Returns:
        An HTML fragment. All user text and labels are HTML-escaped, and
        newlines become ``<br>`` so the fragment stays a single HTML block when
        embedded in ``st.markdown``. Entities that are empty or overlap an
        already-rendered entity are skipped, since rendering them would repeat
        text.
    """
    import html

    def _plain(segment):
        # Escape so input such as "BP < 90" or "<script>" is shown literally
        # instead of being interpreted as markup.
        return html.escape(segment).replace('\n', '<br>')

    parts = []
    last_idx = 0
    for start, end, label in sorted(entities, key=lambda ent: ent[0]):
        if start < last_idx or end <= start:
            continue
        parts.append(_plain(text[last_idx:start]))
        # Unknown labels fall back to light grey.
        color = LABEL_COLORS.get(label, '#D3D3D3')
        badge_style = (
            f"background-color: {color}; color: {_text_color_for(color)}; "
            "font-weight: bold; padding: 2px; border-radius: 4px; margin: 1px;"
        )
        parts.append(
            f'<span style="{badge_style}">{_plain(text[start:end])} '
            '<span style="font-size: smaller; font-weight: normal;">'
            f'[{html.escape(str(label))}]</span></span>'
        )
        last_idx = end
    parts.append(_plain(text[last_idx:]))
    return ''.join(parts)


def predict_ner(text):
    """Run the NER model over ``text`` and return the result as an HTML string.

    Args:
        text: Raw user-entered text; may span several lines.

    Returns:
        HTML in which each predicted entity is highlighted and followed by its
        ``[label]``. If prediction or rendering raises, a red error paragraph
        (with the escaped exception message) is returned instead, so the caller
        can always render the result.
    """
    try:
        entities = predict_multi_line_text(
            text,
            model_1,
            index_to_label,
            acronyms_to_entities,
            MAX_LENGTH,
        )
        return _render_entities(text, entities)
    except Exception as e:  # UI boundary: surface any failure on the page
        import html
        return f"<p style='color:red;'>Error: {html.escape(str(e))}</p>"
# Page-level settings: browser tab title, favicon, and wide layout. The dark
# appearance itself comes from the custom CSS injected right after this.
st.set_page_config(
    layout="wide",
    page_title="Medical NER",
    page_icon="🩺",
)
# Custom stylesheet: dark app and text-area backgrounds with white text, a
# blue Analyze button, dark WebKit scrollbars, and a .highlighted-entity class.
_CUSTOM_CSS = """
<style>
/* Main app background */
.stApp {
background-color: #2E2E2E;
color: #FFFFFF;
}
/* Text input area */
.stTextArea textarea {
background-color: #1E1E1E;
color: #FFFFFF;
}
/* Adjust the Analyze button */
div.stButton > button:first-child {
background-color: #1E90FF;
color: #FFFFFF;
}
/* Scrollbar styling */
::-webkit-scrollbar {
width: 10px;
}
::-webkit-scrollbar-track {
background: #1E1E1E;
}
::-webkit-scrollbar-thumb {
background: #888;
}
::-webkit-scrollbar-thumb:hover {
background: #555;
}
/* Style for the highlighted entities */
.highlighted-entity {
padding: 2px;
border-radius: 4px;
margin: 1px;
font-weight: bold;
display: inline-block;
}
</style>
"""
# Raw HTML is required for a <style> tag to take effect.
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
# Page heading followed by a one-line usage hint.
st.title("🩺 Medical Named Entity Recognition")
st.markdown(
    "\n"
    "Enter medical text below to identify and highlight entities such as "
    "diseases, medications, and anatomical terms.\n"
)
# Input area and Analyze button. predict_ner returns highlighted HTML, so the
# result is rendered with unsafe_allow_html inside a larger-font wrapper.
text_input = st.text_area("Enter medical text here:", height=200)
analyze_clicked = st.button("Analyze")

if analyze_clicked and not text_input.strip():
    st.warning("Please enter some text to analyze.")
elif analyze_clicked:
    with st.spinner("Analyzing..."):
        result = predict_ner(text_input)
        st.markdown(
            "<div style='font-size: 18px;'>" + result + "</div>",
            unsafe_allow_html=True,
        )