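"""Medical X-ray analysis demo (Streamlit).

Identifies the organ in an uploaded X-ray with a ResNet-18 classifier,
flags the image as normal or abnormal, and asks a Groq-hosted LLM for
insights on abnormal results. Run with: streamlit run app.py
"""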
import os
import streamlit as st
from PIL import Image
import torch
from torchvision import transforms, models
import numpy as np
from groq import Groq
# Set up environment variables. Supply the Groq API key via the environment
# or Streamlit secrets; never commit a real key to source control.
os.environ.setdefault("GROQ_API_KEY", "YOUR_GROQ_API_KEY")
# Initialize Groq client
client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
# Load the pretrained model for organ recognition (cached across Streamlit reruns)
@st.cache_resource
def load_organ_model():
    model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)  # ImageNet-pretrained ResNet-18
    num_features = model.fc.in_features  # Input features of the final layer
    # Replace the head for 4 organ classes; note this new layer is randomly
    # initialized, so meaningful predictions require fine-tuned weights.
    model.fc = torch.nn.Linear(num_features, 4)
    model.eval()  # Evaluation mode for inference
    return model
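
# The prediction functions below reference a `chexnet_model` that the original
# file never defines. As a stand-in (an assumption, not the published CheXNet),
# load a DenseNet-121 with a 2-class head; real use would load actual
# CheXNet weights here instead.
@st.cache_resource
def load_chexnet_model():
    model = models.densenet121(weights=models.DenseNet121_Weights.DEFAULT)
    model.classifier = torch.nn.Linear(model.classifier.in_features, 2)  # Normal / Abnormal
    model.eval()
    return model

# Instantiate both models once so the prediction functions can use them
organ_model = load_organ_model()
chexnet_model = load_chexnet_model()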
# Image preprocessing: resize, tensorize, and normalize with ImageNet statistics
def preprocess_image(image):
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    # X-rays are often grayscale; convert to RGB so the 3-channel normalization applies
    return transform(image.convert("RGB")).unsqueeze(0)
# Groq API for AI insights
def get_ai_insights(text_prompt):
    try:
        response = client.chat.completions.create(
            messages=[{"role": "user", "content": text_prompt}],
            model="llama-3.3-70b-versatile"
        )
        return response.choices[0].message.content
    except Exception as e:
        return f"Error: {e}"
# Organ recognition prediction
def predict_organ(image):
    with torch.no_grad():
        input_tensor = preprocess_image(image)
        output = organ_model(input_tensor)
    # Surface the raw output shape (debugging aid)
    st.write(f"Model output shape: {output.shape}")
    # Ensure the output matches the number of classes
    classes = ["Lungs", "Heart", "Spine", "Other"]
    if output.size(1) != len(classes):
        raise ValueError(
            f"Model output size ({output.size(1)}) does not match the number of classes ({len(classes)})."
        )
    # Map the highest-scoring logit to its class label
    prediction_index = output.argmax().item()
    return classes[prediction_index]
# Predict normal vs. abnormal
def predict_normal_abnormal(image):
    with torch.no_grad():
        output = chexnet_model(preprocess_image(image))
    classes = ["Normal", "Abnormal"]
    return classes[output.argmax().item()]
# Streamlit App
st.title("Medical X-ray Analysis App")
st.sidebar.title("Navigation")
task = st.sidebar.radio("Select a task", ["Upload X-ray", "AI Insights"])
if task == "Upload X-ray":
    uploaded_file = st.file_uploader("Upload an X-ray image", type=["jpg", "png", "jpeg"])
    if uploaded_file:
        image = Image.open(uploaded_file)
        st.image(image, caption="Uploaded X-ray", use_column_width=True)

        # Step 1: identify the organ
        st.subheader("Step 1: Identify the Organ")
        organ = predict_organ(image)
        st.write(f"Predicted Organ: **{organ}**")

        # Step 2: classify normal vs. abnormal
        st.subheader("Step 2: Analyze the X-ray")
        classification = predict_normal_abnormal(image)
        st.write(f"X-ray Status: **{classification}**")

        # Step 3: ask the LLM for an explanation only when the X-ray is abnormal
        if classification == "Abnormal":
            st.subheader("Step 3: AI-Based Insights")
            ai_prompt = f"Explain why this X-ray of the {organ} is abnormal."
            insights = get_ai_insights(ai_prompt)
            st.write(insights)
elif task == "AI Insights":
    st.subheader("Ask AI")
    user_input = st.text_area("Enter your query for AI insights")
    if user_input:
        response = get_ai_insights(user_input)
        st.write(response)