import numpy as np
import pandas as pd
import streamlit as st
import torch
from sklearn.preprocessing import MultiLabelBinarizer
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# Check if a GPU is available; otherwise run everything on the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Load the trained model and tokenizer
@st.cache_resource
def load_model():
    """Build the classifier and its tokenizer, ready for inference.

    The BiomedBERT sequence-classification architecture is instantiated with
    8 multi-label outputs, the trained weights are read from
    ``best_model_v2.pth``, and the model is moved to ``device`` and put in
    eval mode.

    Streamlit re-runs the whole script on every widget interaction, so the
    result is cached with ``st.cache_resource``; without it the checkpoint
    would be reloaded each time the user clicks a button.

    Returns:
        tuple: ``(model, tokenizer)``.
    """
    model_name = "microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract"
    model = AutoModelForSequenceClassification.from_pretrained(
        model_name,
        num_labels=8,  # Adjust based on your label count
        problem_type="multi_label_classification",
    )
    # map_location lets a checkpoint that was saved on a GPU load on a
    # CPU-only host instead of failing while deserializing CUDA tensors.
    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    state_dict = torch.load('best_model_v2.pth', map_location=device)
    model.load_state_dict(state_dict)
    model = model.to(device)  # Move the model to the correct device
    model.eval()
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    return model, tokenizer
def load_mlb():
    """Return a MultiLabelBinarizer fitted on the fixed set of output codes.

    The binarizer is created with an explicit ``classes`` list, so column
    ``i`` of an indicator matrix maps to ``classes[i]``.

    NOTE(review): this order presumably has to match the label column order
    used when the model was trained -- confirm against the training code.

    Returns:
        MultiLabelBinarizer: fitted on the eight codes listed below.
    """
    # Define the classes based on your label set
    # Earlier ordering, kept for reference:
    # ['E11.9', 'I10', 'J45.909', 'M54.5', 'N39.0', '81001.0', '99213.0', '99214.0']
    classes = [
        '81001.0', '99213.0', '99214.0',
        'E11.9', 'I10', 'J45.909', 'M54.5', 'N39.0',
    ]
    # Initialize and fit the MultiLabelBinarizer
    mlb = MultiLabelBinarizer(classes=classes)
    mlb.fit([classes])  # Fit with the full list of labels as a single sample
    return mlb
# # Load MultiLabelBinarizer | |
# @st.cache_resource | |
# def load_mlb(): | |
# mlb = MultiLabelBinarizer() | |
# # mlb.classes_ = np.load('mlb_classes.npy') # Assuming you saved the classes array during training | |
# mlb = MultiLabelBinarizer(classes=['E11.9', 'I10', 'J45.909', 'M54.5', | |
# 'N39.0', '81001.0', '99213.0', '99214.0']) # Update with actual labels | |
# return mlb | |
model, tokenizer = load_model()
mlb = load_mlb()

# Labels displayed as CPT codes; every other predicted label is displayed as
# an ICD code. Adjust based on your CPT code list.
_CPT_CODES = frozenset({'81001.0', '99213.0', '99214.0'})

# Streamlit UI
st.title("Automated Coding and Billing Prediction")

# Text input for Clinical Notes
clinical_note = st.text_area("Enter clinical notes to predict ICD and CPT codes")

# Prediction button
if st.button('Predict'):
    # Treat whitespace-only input the same as an empty box.
    if clinical_note and clinical_note.strip():
        # Tokenize the input clinical note
        inputs = tokenizer(
            clinical_note,
            truncation=True,
            padding="max_length",
            max_length=512,
            return_tensors='pt',
        )
        # Move inputs to the same device as the model
        inputs = {key: val.to(device) for key, val in inputs.items()}

        # Model inference
        with torch.no_grad():
            logits = model(**inputs).logits

        # Apply sigmoid and threshold the output (0.5 for multi-label
        # classification); cast to a 0/1 integer indicator matrix.
        pred_labels = (torch.sigmoid(logits) > 0.5).cpu().numpy().astype(int)

        # inverse_transform returns one tuple of labels per input row, so an
        # empty prediction is [()] -- a non-empty list. Flatten the rows and
        # test the codes themselves, otherwise the "no codes" branch can
        # never run.
        predicted_codes = mlb.inverse_transform(pred_labels)
        codes = [code for row in predicted_codes for code in row]

        # Format the results for better display
        if codes:
            st.write("**Predicted ICD and CPT Codes:**")
            for code in codes:
                if code in _CPT_CODES:
                    st.write(f"- **CPT Code:** {code}")
                else:
                    st.write(f"- **ICD Code:** {code}")
        else:
            st.write("No codes predicted.")
    else:
        st.write("Please enter clinical notes for prediction.")
# # Prediction button | |
# if st.button('Predict'): | |
# if clinical_note: | |
# # Tokenize the input clinical note | |
# inputs = tokenizer(clinical_note, truncation=True, padding="max_length", max_length=512, return_tensors='pt') | |
# # Move inputs to the GPU if available | |
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
# inputs = {key: val.to(device) for key, val in inputs.items()} | |
# # Model inference | |
# with torch.no_grad(): | |
# outputs = model(**inputs) | |
# logits = outputs.logits | |
# # Apply sigmoid and threshold the output (0.5 for multi-label classification) | |
# pred_labels = (torch.sigmoid(logits) > 0.5).cpu().numpy() | |
# # Get the predicted ICD and CPT codes | |
# predicted_codes = mlb.inverse_transform(pred_labels) | |
# # Show the results | |
# st.write("Predicted ICD and CPT Codes:") | |
# st.write(predicted_codes) | |
# else: | |
# st.write("Please enter clinical notes for prediction.") | |