File size: 3,274 Bytes
8a359f9
4265985
8a359f9
4c0ed57
 
 
 
 
 
 
8a359f9
4c0ed57
 
 
 
 
 
 
 
 
8066bd0
4c0ed57
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4265985
 
 
4c0ed57
4265985
4c0ed57
 
 
 
4265985
4c0ed57
 
 
 
4265985
4c0ed57
 
 
 
 
4265985
4c0ed57
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import os
import streamlit as st
from PIL import Image
import torch
from torchvision import transforms, models
import numpy as np
from groq import Groq

# Groq credentials must come from the environment, never from source control.
# NOTE(review): the key that used to be hard-coded here was committed to the
# repository and must be treated as leaked — revoke/rotate it in the Groq console.
_groq_api_key = os.environ.get("GROQ_API_KEY")

# Initialize Groq client. Without a key the AI features are disabled, but the
# X-ray classification part of the app keeps working.
if _groq_api_key:
    client = Groq(api_key=_groq_api_key)
else:
    client = None
    st.sidebar.warning("GROQ_API_KEY is not set; AI insights are unavailable.")

# Load Pretrained Models
@st.cache_resource
def load_model():
    """Build the two classifiers used by the app (cached across Streamlit reruns).

    Returns:
        (organ_model, chexnet_model): an EfficientNet-B0 with a 4-way head
        (one logit per organ label) and a DenseNet-121 with a 2-way head
        (Normal, Abnormal), both switched to eval mode.

    NOTE(review): only the ImageNet backbones are pretrained. Both task heads
    below are freshly created layers with no task weights loaded, so their
    predictions are placeholders until fine-tuned weights are loaded (e.g. via
    ``model.load_state_dict(...)``). Despite the variable name, the DenseNet
    is not CheXNet.
    """
    # EfficientNet-B0 backbone for organ recognition. The stock head emits the
    # 1000 ImageNet logits; replace it with one logit per organ label so the
    # argmax is a valid index into the 4-entry label list in predict_organ.
    organ_model = models.efficientnet_b0(weights=models.EfficientNet_B0_Weights.DEFAULT)
    organ_model.classifier[-1] = torch.nn.Linear(organ_model.classifier[-1].in_features, 4)
    organ_model.eval()

    # DenseNet-121 backbone with a 2-way Normal/Abnormal head.
    chexnet_model = models.densenet121(weights=models.DenseNet121_Weights.DEFAULT)
    chexnet_model.classifier = torch.nn.Linear(chexnet_model.classifier.in_features, 2)
    chexnet_model.eval()

    return organ_model, chexnet_model


organ_model, chexnet_model = load_model()

# Image Preprocessing
# Built once at import time: the pipeline holds no per-image state, so there is
# no need to rebuild it on every call.
_PREPROCESS = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    # ImageNet per-channel mean/std, matching the ImageNet-pretrained backbones.
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])


def preprocess_image(image):
    """Turn a PIL image into a normalised ``(1, 3, 224, 224)`` float tensor.

    The image is converted to RGB first. X-rays are commonly stored as
    single-channel grayscale or RGBA PNGs, and a 1- or 4-channel tensor can
    neither be normalised with the 3-channel statistics above nor fed to the
    3-channel models.
    """
    return _PREPROCESS(image.convert("RGB")).unsqueeze(0)

# Groq API for AI Insights
def get_ai_insights(text_prompt, model="llama-3.3-70b-versatile"):
    """Send a single-turn prompt to the Groq chat API and return the reply text.

    Args:
        text_prompt: The user message to send.
        model: Groq model identifier; defaults to the model the app was built with.

    Returns:
        The assistant's reply text, or a human-readable ``"Error: ..."`` string
        when the client is not configured or the request fails. Errors are
        returned rather than raised because the result is rendered directly
        with ``st.write`` by the UI.
    """
    if client is None:
        return "Error: Groq client is not configured (set GROQ_API_KEY)."
    try:
        response = client.chat.completions.create(
            messages=[{"role": "user", "content": text_prompt}],
            model=model,
        )
        # Message content may be None; never hand None to st.write.
        return response.choices[0].message.content or ""
    except Exception as e:  # UI boundary: surface any API/network failure to the user
        return f"Error: {e}"

# Predict Organ
def predict_organ(image):
    """Return the organ label predicted by ``organ_model`` for a PIL image.

    Raises:
        ValueError: if the model emits a different number of logits than there
            are labels. For example, a stock 1000-class ImageNet head would
            otherwise either raise an opaque IndexError or silently map an
            unrelated class index onto one of these labels.
    """
    classes = ["Lungs", "Heart", "Spine", "Other"]  # Example classes
    with torch.no_grad():
        output = organ_model(preprocess_image(image))
    if output.shape[-1] != len(classes):
        raise ValueError(
            f"organ_model produced {output.shape[-1]} logits but {len(classes)} "
            "organ labels are defined; give the model a matching classification head."
        )
    # preprocess_image yields a batch of one image, so take row 0's argmax.
    return classes[output.argmax(dim=-1)[0].item()]

# Predict Normal/Abnormal
def predict_normal_abnormal(image):
    """Classify a PIL image as "Normal" or "Abnormal" with ``chexnet_model``.

    The label is the argmax over the model's logits for the single-image batch
    produced by ``preprocess_image``.

    NOTE(review): the model's 2-way head is created in load_model without any
    task weights being loaded, so this label is a placeholder until fine-tuned
    weights are supplied.
    """
    labels = ("Normal", "Abnormal")
    batch = preprocess_image(image)
    with torch.no_grad():
        logits = chexnet_model(batch)
        winner = int(logits.argmax())
    return labels[winner]

# Streamlit App
st.title("Medical X-ray Analysis App")
# The classification heads have no task-specific weights loaded, and the LLM
# explanation is generated from the organ name alone (it never sees the image),
# so make clear this is not a diagnosis.
st.caption("Experimental tool for demonstration only; results are not a medical diagnosis.")
st.sidebar.title("Navigation")
task = st.sidebar.radio("Select a task", ["Upload X-ray", "AI Insights"])

if task == "Upload X-ray":
    uploaded_file = st.file_uploader("Upload an X-ray image", type=["jpg", "png", "jpeg"])

    if uploaded_file:
        try:
            # Decode eagerly and normalise to RGB: X-ray PNGs are often
            # grayscale or RGBA, which the 3-channel preprocessing cannot handle.
            image = Image.open(uploaded_file).convert("RGB")
        except OSError:
            # Covers PIL's UnidentifiedImageError (an OSError subclass) and
            # truncated files: an image extension does not guarantee a
            # decodable image.
            st.error("The uploaded file could not be read as an image.")
            st.stop()
        st.image(image, caption="Uploaded X-ray", use_column_width=True)

        # Predict Organ
        st.subheader("Step 1: Identify the Organ")
        organ = predict_organ(image)
        st.write(f"Predicted Organ: **{organ}**")

        # Predict Normal/Abnormal
        st.subheader("Step 2: Analyze the X-ray")
        classification = predict_normal_abnormal(image)
        st.write(f"X-ray Status: **{classification}**")

        if classification == "Abnormal":
            st.subheader("Step 3: AI-Based Insights")
            # Only the predicted organ name is sent to the LLM, not the image.
            ai_prompt = f"Explain why this X-ray of the {organ} is abnormal."
            insights = get_ai_insights(ai_prompt)
            st.write(insights)

elif task == "AI Insights":
    st.subheader("Ask AI")
    user_input = st.text_area("Enter your query for AI insights")
    # Skip whitespace-only input instead of spending an API call on it.
    if user_input.strip():
        response = get_ai_insights(user_input)
        st.write(response)