# abuzarAli's picture
# Update app.py
# 4c0ed57 verified
# raw
# history blame
# 3.27 kB
import os
import streamlit as st
from PIL import Image
import torch
from torchvision import transforms, models
import numpy as np
from groq import Groq
# Groq client setup.
# SECURITY: the API key is read from the environment (for example a
# deployment secret) and must never be written into this file. Any key that
# was previously committed here is compromised and has to be revoked.
_GROQ_API_KEY = os.environ.get("GROQ_API_KEY")
if not _GROQ_API_KEY:
    # Fail early with a readable message instead of an SDK traceback.
    st.error(
        "GROQ_API_KEY is not set. Configure it as an environment variable "
        "or deployment secret and restart the app."
    )
    st.stop()
# Initialize Groq client
client = Groq(api_key=_GROQ_API_KEY)
# Model loading
@st.cache_resource
def load_model():
    """Build the two torchvision networks used by the app, both in eval mode.

    Returns:
        tuple: ``(organ_model, chexnet_model)``. The first is an
        EfficientNet-B0 created with ``pretrained=True``; the second is a
        DenseNet-121 created with ``pretrained=True`` whose ``classifier``
        is replaced by a new 2-output linear layer (Normal, Abnormal).

    NOTE(review): nothing here loads task-specific weights. The
    EfficientNet keeps its stock classifier and the DenseNet head is
    freshly initialised, so confirm whether trained checkpoints are
    supposed to be loaded before relying on the predictions.
    """
    # Organ recognition network (built first, as before, so the random
    # initialisation of the new head below sees the same RNG state).
    organ_net = models.efficientnet_b0(pretrained=True).eval()

    # Normal/abnormal network: DenseNet-121 backbone with a two-way head
    # sized from the classifier it replaces.
    abnormality_net = models.densenet121(pretrained=True)
    head_inputs = abnormality_net.classifier.in_features
    abnormality_net.classifier = torch.nn.Linear(head_inputs, 2)
    abnormality_net.eval()

    return organ_net, abnormality_net


organ_model, chexnet_model = load_model()
# Image Preprocessing
def preprocess_image(image):
    """Turn an input image into a normalised, batched tensor for the models.

    Args:
        image: Normally a ``PIL.Image.Image`` in any mode. Other inputs
            accepted by the torchvision transforms are passed through
            unchanged.

    Returns:
        The transformed image with a leading batch dimension of size 1.
        For PIL input this is a 3-channel 224x224 tensor.
    """
    # X-rays are very often single-channel ("L") and PNG uploads may carry
    # an alpha channel. The 3-value mean/std below and the networks' input
    # layers need exactly three channels, so normalise the mode up front.
    if isinstance(image, Image.Image) and image.mode != "RGB":
        image = image.convert("RGB")

    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        # Per-channel mean / std applied after ToTensor.
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    return transform(image).unsqueeze(0)
# Groq API for AI Insights
def get_ai_insights(text_prompt):
    """Send ``text_prompt`` as a single user message and return the reply text.

    The request goes through the module-level ``client`` using the
    ``llama-3.3-70b-versatile`` model.

    Args:
        text_prompt: Text placed in the ``content`` of the user message.

    Returns:
        str: The content of the first returned choice, or the string
        ``"Error: <exception>"`` if anything raised while making the
        request or reading the response (nothing is propagated).
    """
    chat_messages = [{"role": "user", "content": text_prompt}]
    try:
        completion = client.chat.completions.create(
            model="llama-3.3-70b-versatile",
            messages=chat_messages,
        )
        first_choice = completion.choices[0]
        return first_choice.message.content
    except Exception as err:
        # UI boundary: report the failure as text instead of crashing the page.
        return f"Error: {err}"
# Predict Organ
def predict_organ(image):
    """Return an organ label for ``image`` based on ``organ_model``'s top class.

    Args:
        image: Image accepted by ``preprocess_image``.

    Returns:
        str: One of ``"Lungs"``, ``"Heart"``, ``"Spine"`` or ``"Other"``.

    NOTE(review): ``organ_model`` is created without an organ-specific
    head, so its output indices are not organ classes. The labels below
    are placeholders until a trained head/checkpoint is plugged in.
    """
    classes = ["Lungs", "Heart", "Spine", "Other"]  # Example classes
    fallback = "Other"

    with torch.no_grad():
        output = organ_model(preprocess_image(image))
    index = output.argmax().item()

    # The network can emit far more outputs than there are labels here;
    # indexing the list directly raised IndexError for any index >= 4.
    if index >= len(classes):
        return fallback
    return classes[index]
# Predict Normal/Abnormal
def predict_normal_abnormal(image):
    """Classify ``image`` as ``"Normal"`` or ``"Abnormal"`` with ``chexnet_model``.

    Args:
        image: Image accepted by ``preprocess_image``.

    Returns:
        str: The label whose position matches the largest model output.
    """
    labels = ("Normal", "Abnormal")
    batch = preprocess_image(image)
    with torch.no_grad():  # inference only, no gradient bookkeeping
        logits = chexnet_model(batch)
    top_index = int(torch.argmax(logits))
    return labels[top_index]
# Streamlit App
# Two pages selected from the sidebar:
#   * "Upload X-ray": show the upload, run both classifiers and, for an
#     "Abnormal" result, ask the language model for an explanation.
#   * "AI Insights": free-form question forwarded to the language model.
st.title("Medical X-ray Analysis App")
st.sidebar.title("Navigation")
task = st.sidebar.radio("Select a task", ["Upload X-ray", "AI Insights"])

if task == "Upload X-ray":
    uploaded_file = st.file_uploader(
        "Upload an X-ray image", type=["jpg", "png", "jpeg"]
    )
    if uploaded_file:
        try:
            image = Image.open(uploaded_file)
            # Image.open is lazy; decode now so a corrupt or mislabelled
            # file is reported here rather than as a traceback further down.
            image.load()
        except OSError as exc:
            st.error(f"Could not read the uploaded image: {exc}")
        else:
            st.image(image, caption="Uploaded X-ray", use_column_width=True)

            # Predict Organ
            st.subheader("Step 1: Identify the Organ")
            organ = predict_organ(image)
            st.write(f"Predicted Organ: **{organ}**")

            # Predict Normal/Abnormal
            st.subheader("Step 2: Analyze the X-ray")
            classification = predict_normal_abnormal(image)
            st.write(f"X-ray Status: **{classification}**")

            if classification == "Abnormal":
                st.subheader("Step 3: AI-Based Insights")
                # Only this text is sent; the image itself is not part of
                # the request.
                ai_prompt = f"Explain why this X-ray of the {organ} is abnormal."
                insights = get_ai_insights(ai_prompt)
                st.write(insights)
elif task == "AI Insights":
    st.subheader("Ask AI")
    user_input = st.text_area("Enter your query for AI insights")
    # Skip whitespace-only input so no empty request is sent to the API.
    if user_input and user_input.strip():
        response = get_ai_insights(user_input)
        st.write(response)