# Source page metadata (not part of the program):
#   author: kavithapadala
#   commit: "Update app.py" (b3b69b9, verified)
#   page links: raw / history / blame
#   size: 3.4 kB
import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import pandas as pd
import numpy as np
from sklearn.preprocessing import MultiLabelBinarizer
# Check if a GPU is available
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Load the trained model and tokenizer
@st.cache_resource
def load_model():
    """Build the multi-label classifier and its tokenizer.

    The BiomedBERT base checkpoint is instantiated with an 8-output
    multi-label classification head, after which the fine-tuned weights
    stored in ``best_model_v2.pth`` are loaded onto the CPU and the model
    is put into evaluation mode.

    Returns:
        A ``(model, tokenizer)`` pair. Because of ``st.cache_resource`` the
        pair is created once and reused across Streamlit reruns.
    """
    checkpoint_name = "microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract"

    tokenizer = AutoTokenizer.from_pretrained(checkpoint_name)

    model = AutoModelForSequenceClassification.from_pretrained(
        checkpoint_name,
        num_labels=8,  # Adjust based on your label count
        problem_type="multi_label_classification",
    )

    # Fine-tuned weights are always mapped to the CPU, regardless of the
    # device they were saved from.
    fine_tuned_weights = torch.load(
        'best_model_v2.pth', map_location=torch.device('cpu')
    )
    model.load_state_dict(fine_tuned_weights)
    model.eval()

    return model, tokenizer
@st.cache_resource
def load_mlb():
    """Return a fitted ``MultiLabelBinarizer`` for the known code labels.

    The binarizer is given an explicit ``classes`` ordering and then fitted
    on a single sample containing every label, so it is ready for
    ``inverse_transform`` without any training data.

    NOTE(review): the label order here presumably has to match the column
    order of the model's outputs -- confirm against the training script.
    """
    # Define the classes based on your label set
    label_set = [
        '81001.0',
        '99213.0',
        '99214.0',
        'E11.9',
        'I10',
        'J45.909',
        'M54.5',
        'N39.0',
    ]
    # fit() returns the binarizer itself, so build-and-fit can be chained.
    return MultiLabelBinarizer(classes=label_set).fit([label_set])
# Earlier draft of load_mlb (superseded by the active definition above; kept for reference only):
# @st.cache_resource
# def load_mlb():
# mlb = MultiLabelBinarizer()
# # mlb.classes_ = np.load('mlb_classes.npy') # Assuming you saved the classes array during training
# mlb = MultiLabelBinarizer(classes=['E11.9', 'I10', 'J45.909', 'M54.5',
# 'N39.0', '81001.0', '99213.0', '99214.0']) # Update with actual labels
# return mlb
# Load the (cached) model, tokenizer and label binarizer.
model, tokenizer = load_model()
mlb = load_mlb()

# Labels reported as CPT codes; every other predicted label is reported as
# an ICD code. Built once as a set so the membership test is O(1).
CPT_CODES = {'81001.0', '99213.0', '99214.0'}  # Adjust based on your CPT code list

# Streamlit UI
st.title("Automated Medical Coding Prediction")

# Text input for Clinical Notes
clinical_note = st.text_area("Enter clinical notes")

# Prediction button
if st.button('Predict'):
    # Treat whitespace-only input the same as empty input.
    if clinical_note.strip():
        # Tokenize the input clinical note (truncated/padded to 512 tokens).
        inputs = tokenizer(
            clinical_note,
            truncation=True,
            padding="max_length",
            max_length=512,
            return_tensors='pt',
        )

        # Model inference. load_model() maps the weights to the CPU and never
        # moves the model, so the tokenizer's tensors are passed through
        # without an explicit device transfer.
        with torch.no_grad():
            logits = model(**inputs).logits

        # Apply sigmoid and threshold the output (0.5 for multi-label classification)
        pred_labels = (torch.sigmoid(logits) > 0.5).cpu().numpy()

        # inverse_transform yields one tuple of labels per input row. With no
        # label above the threshold the result is [()], which is a non-empty
        # (truthy) list -- so the labels are flattened before testing for
        # "nothing predicted".
        predicted_codes = mlb.inverse_transform(pred_labels)
        codes = [code for row in predicted_codes for code in row]

        if codes:
            st.write("**Predicted CPT and ICD Codes:**")
            for code in codes:
                code_type = "CPT" if code in CPT_CODES else "ICD"
                st.write(f"- **{code_type} Code:** {code}")
        else:
            st.write("No codes predicted.")
    else:
        st.write("Please enter clinical notes for prediction.")