kavithapadala's picture
Upload app.py
463a6b5 verified
raw
history blame
3.33 kB
import streamlit as st
from PIL import Image
import torch
from torchvision import transforms
from transformers import AutoModelForImageClassification
import pandas as pd
# Load the ground-truth metadata table (used to show the true labels of an upload).
@st.cache_data
def load_dataset(dataset_path="./Data_Entry_2017_v2020.csv"):
    """Read the ground-truth metadata CSV into a DataFrame.

    The app looks up an uploaded file's name in this table's 'Image Index'
    column and displays the matching 'Finding Labels' value. The result is
    cached with ``st.cache_data`` so the CSV is not re-parsed on every
    Streamlit rerun; the cache is keyed on ``dataset_path``.

    Args:
        dataset_path: Path to the CSV file. Defaults to the file expected
            next to app.py; pass another path to use a different dataset.

    Returns:
        pandas.DataFrame holding the CSV contents.
    """
    return pd.read_csv(dataset_path)


data = load_dataset()
@st.cache_resource
def load_model(checkpoint_path="best_model_new_retrain.pth", num_labels=15):
    """Build the ViT image classifier and load fine-tuned weights on the CPU.

    The architecture is created from the "google/vit-base-patch16-224-in21k"
    pretrained config with a classification head of ``num_labels`` outputs.
    The head is then overwritten by the state dict stored in
    ``checkpoint_path``. The model is put in eval mode and cached with
    ``st.cache_resource``, so it is built once rather than on every rerun.

    Args:
        checkpoint_path: Path to a ``torch.save``'d state dict for this
            architecture.
        num_labels: Size of the classification head. It must match the
            checkpoint, or ``load_state_dict`` raises.

    Returns:
        The model in eval mode, on the CPU.
    """
    # Define the model architecture
    model = AutoModelForImageClassification.from_pretrained(
        "google/vit-base-patch16-224-in21k", num_labels=num_labels
    )
    # Only a state dict (tensors in plain containers) is needed here.
    # weights_only=True refuses arbitrary pickled objects, so a tampered
    # checkpoint cannot execute code when it is loaded.
    state_dict = torch.load(
        checkpoint_path,
        map_location=torch.device('cpu'),
        weights_only=True,
    )
    model.load_state_dict(state_dict)
    model.eval()
    return model


model = load_model()
# Preprocessing applied to each uploaded image before inference:
# resize, convert the PIL image to a tensor, then normalize each channel.
# NOTE(review): these are the ImageNet mean/std. Confirm they match the
# normalization used when the checkpoint was trained.
_RESIZE_HW = (224, 224)  # Adjust based on your model's requirements
_CHANNEL_MEAN = [0.485, 0.456, 0.406]  # ImageNet stats
_CHANNEL_STD = [0.229, 0.224, 0.225]

transform = transforms.Compose(
    [
        transforms.Resize(_RESIZE_HW),
        transforms.ToTensor(),
        transforms.Normalize(mean=_CHANNEL_MEAN, std=_CHANNEL_STD),
    ]
)
# Function to make predictions
def predict_image(image):
    """Score one image with the module-level ``model``.

    The image goes through the module-level ``transform`` and is wrapped in a
    batch of one. The forward pass runs with gradient tracking disabled. Each
    logit is passed through an independent sigmoid, which gives multi-label
    scores rather than a softmax over classes.

    Args:
        image: A PIL image. The app passes uploads already converted
            with ``convert("RGB")``.

    Returns:
        torch.Tensor of per-class probabilities with a leading batch
        dimension of size 1.
    """
    batch = torch.unsqueeze(transform(image), 0)
    with torch.no_grad():
        logits = model(batch).logits
        scores = logits.sigmoid()
    return scores
# Streamlit App
st.title("Chest Xray Disease Prediction App")
st.write("Upload single or multiple images to get predictions.")

# Class names, in the order of the model's output logits. The length must
# equal the model's number of outputs.
# NOTE(review): confirm this order matches the label order used in training.
LABEL_COLUMNS = [
    'No Finding', 'Infiltration', 'Effusion', 'Atelectasis', 'Nodule',
    'Mass', 'Pneumothorax', 'Consolidation', 'Pleural_Thickening',
    'Cardiomegaly', 'Emphysema', 'Edema', 'Fibrosis', 'Pneumonia', 'Hernia',
]
# Probabilities strictly greater than this are flagged in the "Highlight" column.
HIGHLIGHT_THRESHOLD = 0.5


def _highlight_cell(flag):
    """Return the CSS for one "Highlight" cell: yellow when ``flag`` is truthy."""
    return 'background-color: yellow;' if flag else ''


# File uploader for single or bulk images
uploaded_files = st.file_uploader("Upload Image(s)", type=["jpg", "png", "jpeg"], accept_multiple_files=True)

# Process each uploaded file (nothing happens until at least one is uploaded)
for uploaded_file in uploaded_files or []:
    uploaded_filename = uploaded_file.name

    # Load and display the image. Image.open is lazy, and convert() forces the
    # decode. A corrupt or non-image upload therefore raises here. Report it
    # and move on instead of aborting the remaining files.
    try:
        image = Image.open(uploaded_file).convert("RGB")
    except (OSError, Image.DecompressionBombError) as err:
        st.error(f"Could not read {uploaded_filename} as an image: {err}")
        continue
    st.image(image, caption=f"Uploaded Image: {uploaded_filename}", use_column_width=True)

    # Search for the filename in the dataset
    matching_row = data[data['Image Index'] == uploaded_filename]
    truth = matching_row.iloc[0]['Finding Labels'] if not matching_row.empty else "No matching label found"
    st.write(f"**Truth (Ground Truth Labels):** {truth}")

    # Get predictions. predict_image returns a batch of one, so take row 0
    # explicitly rather than squeeze(), which would drop every size-1 dim.
    probabilities = predict_image(image)
    scores = probabilities[0].tolist()
    if len(scores) != len(LABEL_COLUMNS):
        st.error(
            f"Model returned {len(scores)} scores but {len(LABEL_COLUMNS)} labels are defined."
        )
        continue

    prediction_df = pd.DataFrame({
        "Class": LABEL_COLUMNS,
        "Probability": scores,
    })
    prediction_df['Highlight'] = prediction_df['Probability'] > HIGHLIGHT_THRESHOLD

    # Display predictions
    st.write("**Prediction (Model Probabilities):**")
    styler = prediction_df.style.format({"Probability": "{:.2f}"})
    # pandas 2.1 renamed Styler.applymap to Styler.map and deprecated the old
    # name. Use the new one when available, and fall back on older pandas.
    apply_elementwise = getattr(styler, "map", None) or styler.applymap
    st.dataframe(apply_elementwise(_highlight_cell, subset=['Highlight']))