File size: 3,701 Bytes
8a359f9
4265985
8a359f9
4c0ed57
 
99d610b
 
4c0ed57
99d610b
4c0ed57
8a359f9
99d610b
 
 
 
 
f68cc3e
4c0ed57
f68cc3e
99d610b
 
 
 
f68cc3e
4c0ed57
78f0c67
4c0ed57
 
 
 
 
 
 
 
 
 
99d610b
 
 
 
 
 
 
 
 
 
 
f68cc3e
4c0ed57
 
f68cc3e
 
99d610b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4c0ed57
 
 
 
99d610b
4c0ed57
 
 
 
 
4265985
 
 
4c0ed57
4265985
4c0ed57
99d610b
4c0ed57
 
4265985
99d610b
 
 
 
 
 
 
 
 
 
 
4c0ed57
 
 
 
99d610b
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import os
import streamlit as st
from PIL import Image
import torch
from torchvision import transforms, models
import numpy as np
from groq import Groq

# Groq credentials
# The API key is read from the GROQ_API_KEY environment variable (set it in the
# shell or as a deployment secret). It must never be written into this file:
# any key that has been committed to source control should be treated as
# compromised and rotated.
GROQ_API_KEY = os.environ.get("GROQ_API_KEY")
if not GROQ_API_KEY:
    # Fail with a readable message instead of a client-construction traceback.
    st.error(
        "GROQ_API_KEY is not set. Configure it as an environment variable "
        "or deployment secret and reload the app."
    )
    st.stop()

# Initialize Groq client
client = Groq(api_key=GROQ_API_KEY)

# Load Pretrained Model for Organ Recognition
@st.cache_resource
def load_organ_model():
    """Build the organ-recognition network once and cache it.

    The function is wrapped in a single ``st.cache_resource`` so Streamlit
    reruns reuse the same model object instead of rebuilding it.

    Returns:
        A torchvision ResNet-18 in evaluation mode whose final fully
        connected layer has been replaced by a 4-output linear layer.

    Note:
        Only the backbone carries pretrained weights. The replacement head is
        freshly initialised and no fine-tuned checkpoint is loaded here, so
        its scores are not meaningful until trained weights are loaded.
    """
    model = models.resnet18(pretrained=True)  # Load pretrained ResNet18
    num_features = model.fc.in_features       # Input width of the final layer
    model.fc = torch.nn.Linear(num_features, 4)  # New head: one output per organ class
    model.eval()  # Inference mode: freezes dropout / batch-norm statistics
    return model



# Image Preprocessing
def preprocess_image(image):
    """Convert a PIL image into a normalised single-image batch tensor.

    Args:
        image: PIL image in any mode. X-ray files are frequently grayscale
            ("L") or carry an alpha channel ("RGBA").

    Returns:
        Tensor of shape (1, 3, 224, 224): resized, scaled by ``ToTensor`` and
        normalised with the three per-channel means / standard deviations
        given below.
    """
    # Normalize is configured with three per-channel statistics, so anything
    # that is not 3-channel RGB must be converted first; otherwise
    # normalisation fails on the channel mismatch.
    if image.mode != "RGB":
        image = image.convert("RGB")

    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    # unsqueeze(0) adds the leading batch dimension the models expect.
    return transform(image).unsqueeze(0)

# Groq API for AI Insights
def get_ai_insights(text_prompt):
    """Ask the Groq chat model a question and return its reply as text.

    The prompt is sent as a single user message to the
    ``llama-3.3-70b-versatile`` model through the module-level ``client``.

    Args:
        text_prompt: The text to send as the user message.

    Returns:
        The content of the first returned choice. If anything goes wrong
        while making the request or reading the reply, the exception is not
        raised; a string of the form ``"Error: <exception>"`` is returned
        instead so the UI can display it.
    """
    conversation = [{"role": "user", "content": text_prompt}]
    try:
        completion = client.chat.completions.create(
            messages=conversation,
            model="llama-3.3-70b-versatile",
        )
        first_choice = completion.choices[0]
        return first_choice.message.content
    except Exception as exc:
        # Deliberate best-effort boundary: surface the failure as text.
        return f"Error: {exc}"
        
# Organ Recognition Prediction
def predict_organ(image):
    """Classify which organ an X-ray image shows.

    Args:
        image: PIL image of the X-ray.

    Returns:
        One of ``"Lungs"``, ``"Heart"``, ``"Spine"`` or ``"Other"``: the
        class with the highest model score.

    Raises:
        ValueError: If the model does not emit exactly one score per class.
    """
    classes = ["Lungs", "Heart", "Spine", "Other"]

    # Obtain the model through the cached loader. No module-level
    # `organ_model` is defined in this file, so referencing one would raise
    # NameError; the loader is cached, so repeated calls are cheap.
    model = load_organ_model()

    with torch.no_grad():
        output = model(preprocess_image(image))

    # Ensure the output matches the number of classes
    if output.size(1) != len(classes):
        raise ValueError(
            f"Model output size ({output.size(1)}) does not match the number of classes ({len(classes)})."
        )

    # The batch holds a single image, so argmax over the class dimension
    # yields exactly one index.
    prediction_index = output.argmax(dim=1).item()
    return classes[prediction_index]


# Load Pretrained Model for Normal/Abnormal classification
@st.cache_resource
def _load_chexnet_model():
    """Build the Normal/Abnormal classifier once and cache it.

    Mirrors ``load_organ_model``: a pretrained torchvision backbone
    (DenseNet-121) whose classifier is replaced by a 2-output linear layer,
    returned in evaluation mode.

    NOTE(review): no fine-tuned checkpoint is loaded in this file, so the
    2-class head is untrained; load trained weights here before relying on
    the predictions.
    """
    model = models.densenet121(pretrained=True)
    model.classifier = torch.nn.Linear(model.classifier.in_features, 2)
    model.eval()
    return model


# Predict Normal/Abnormal
def predict_normal_abnormal(image):
    """Classify an X-ray image as normal or abnormal.

    Args:
        image: PIL image of the X-ray.

    Returns:
        ``"Normal"`` or ``"Abnormal"``: the class with the highest model
        score.
    """
    classes = ["Normal", "Abnormal"]

    # No module-level `chexnet_model` is defined in this file, so referencing
    # one would raise NameError; use the cached loader instead.
    model = _load_chexnet_model()

    with torch.no_grad():
        output = model(preprocess_image(image))

    # Single-image batch: argmax over the class dimension gives one index.
    return classes[output.argmax(dim=1).item()]

# Streamlit App
# Two pages selected from the sidebar:
#   * "Upload X-ray": classify an uploaded image (organ, then normal/abnormal)
#     and, for abnormal results, ask the AI model for an explanation.
#   * "AI Insights": free-form question sent straight to the AI model.
st.title("Medical X-ray Analysis App")
st.sidebar.title("Navigation")
task = st.sidebar.radio("Select a task", ["Upload X-ray", "AI Insights"])

if task == "Upload X-ray":
    uploaded_file = st.file_uploader("Upload an X-ray image", type=["jpg", "png", "jpeg"])

    if uploaded_file is not None:
        try:
            image = Image.open(uploaded_file)
            # Image.open is lazy; decode now so corrupt or truncated files
            # are reported here rather than failing later in the pipeline.
            image.load()
        except OSError:
            # Covers unidentified formats and truncated/corrupt image data.
            st.error("The uploaded file could not be read as an image. Please upload a valid JPG or PNG file.")
            st.stop()

        st.image(image, caption="Uploaded X-ray", use_column_width=True)

        # Predict Organ
        st.subheader("Step 1: Identify the Organ")
        organ = predict_organ(image)
        st.write(f"Predicted Organ: **{organ}**")

        # Predict Normal/Abnormal
        st.subheader("Step 2: Analyze the X-ray")
        classification = predict_normal_abnormal(image)
        st.write(f"X-ray Status: **{classification}**")

        # Only abnormal results trigger the explanation request.
        if classification == "Abnormal":
            st.subheader("Step 3: AI-Based Insights")
            ai_prompt = f"Explain why this X-ray of the {organ} is abnormal."
            insights = get_ai_insights(ai_prompt)
            st.write(insights)

elif task == "AI Insights":
    st.subheader("Ask AI")
    user_input = st.text_area("Enter your query for AI insights")
    # Skip empty or whitespace-only queries instead of sending them to the API.
    if user_input and user_input.strip():
        response = get_ai_insights(user_input)
        st.write(response)