import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from sklearn.preprocessing import MultiLabelBinarizer

# Select the GPU if one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# Load the fine-tuned model and tokenizer (cached across Streamlit reruns)
@st.cache_resource
def load_model():
    model = AutoModelForSequenceClassification.from_pretrained(
        "microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract",
        num_labels=8,  # Adjust based on your label count
        problem_type="multi_label_classification",
    )
    # Load the fine-tuned weights and map them to the selected device
    model.load_state_dict(torch.load("best_model_v2.pth", map_location=device))
    model.to(device)
    model.eval()
    tokenizer = AutoTokenizer.from_pretrained(
        "microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract"
    )
    return model, tokenizer


# Build the MultiLabelBinarizer used to map model outputs back to codes
@st.cache_resource
def load_mlb():
    # Define the classes based on your label set (CPT codes first, then ICD codes)
    classes = ["81001.0", "99213.0", "99214.0", "E11.9", "I10", "J45.909", "M54.5", "N39.0"]
    # Initialize the MultiLabelBinarizer and fit it with the full label list as a single sample
    mlb = MultiLabelBinarizer(classes=classes)
    mlb.fit([classes])
    return mlb


model, tokenizer = load_model()
mlb = load_mlb()

# Streamlit UI
st.title("Automated Medical Coding")
st.write("Enter clinical notes to predict ICD and CPT codes.")

# Text input for clinical notes
clinical_note = st.text_area("Enter clinical notes")

# Prediction button
if st.button("Predict"):
    if clinical_note:
        # Tokenize the input clinical note
        inputs = tokenizer(
            clinical_note,
            truncation=True,
            padding="max_length",
            max_length=512,
            return_tensors="pt",
        )
        # Move the inputs to the same device as the model
        inputs = {key: val.to(device) for key, val in inputs.items()}

        # Model inference
        with torch.no_grad():
            outputs = model(**inputs)
            logits = outputs.logits

        # Apply sigmoid and threshold at 0.5 (standard for multi-label classification)
        pred_labels = (torch.sigmoid(logits) > 0.5).cpu().numpy()

        # Map the binary indicator matrix back to ICD and CPT codes
        predicted_codes = mlb.inverse_transform(pred_labels)

        # inverse_transform returns one tuple of codes per input sample,
        # so check that at least one code was actually predicted
        if any(len(codes) > 0 for codes in predicted_codes):
            st.write("**Predicted CPT and ICD Codes:**")
            for codes in predicted_codes:
                for code in codes:
                    if code in ["81001.0", "99213.0", "99214.0"]:  # Adjust based on your CPT code list
                        st.write(f"- **CPT Code:** {code}")
                    else:
                        st.write(f"- **ICD Code:** {code}")
        else:
            st.write("No codes predicted.")
    else:
        st.write("Please enter clinical notes for prediction.")
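
# To launch this app locally (assuming this file is saved as app.py -- the
# filename is an assumption, substitute your actual script name -- and that
# streamlit, torch, transformers, and scikit-learn are installed), run:
#
#   streamlit run app.py
#
# The trained weights file best_model_v2.pth must be present in the working
# directory for load_model() to succeed.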