Spaces:
Sleeping
Sleeping
File size: 3,701 Bytes
8a359f9 4265985 8a359f9 4c0ed57 99d610b 4c0ed57 99d610b 4c0ed57 8a359f9 99d610b f68cc3e 4c0ed57 f68cc3e 99d610b f68cc3e 4c0ed57 78f0c67 4c0ed57 99d610b f68cc3e 4c0ed57 f68cc3e 99d610b 4c0ed57 99d610b 4c0ed57 4265985 4c0ed57 4265985 4c0ed57 99d610b 4c0ed57 4265985 99d610b 4c0ed57 99d610b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 |
import os
import streamlit as st
from PIL import Image
import torch
from torchvision import transforms, models
import numpy as np
from groq import Groq
# Set up the Groq client.
# The API key is read from the GROQ_API_KEY environment variable (e.g. a
# deployment secret); it must never be hard-coded in source control.
_groq_api_key = os.environ.get("GROQ_API_KEY")
if _groq_api_key:
    client = Groq(api_key=_groq_api_key)
else:
    # Keep the image-analysis features usable without a key; AI-insight
    # requests will fail and report an error instead of crashing startup.
    client = None
    st.sidebar.warning("GROQ_API_KEY is not set; AI insights are unavailable.")
# Load Pretrained Model for Organ Recognition
@st.cache_resource
def load_organ_model(num_classes: int = 4) -> torch.nn.Module:
    """Build the organ-recognition classifier.

    Starts from torchvision's ImageNet-pretrained ResNet-18 and replaces its
    final fully connected layer with a new ``num_classes``-way linear head.
    ``st.cache_resource`` makes Streamlit build the model once and reuse it
    across reruns instead of reloading it on every interaction.

    NOTE(review): the replacement head is freshly (randomly) initialised and
    no fine-tuned weights are loaded here, so its predictions carry no
    trained meaning until real weights are supplied — confirm before use.

    Args:
        num_classes: Width of the new output layer. Defaults to 4.

    Returns:
        The model, switched to evaluation mode.
    """
    # `weights=` replaces the deprecated `pretrained=True`; DEFAULT selects
    # the same ImageNet checkpoint.
    model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
    model.fc = torch.nn.Linear(model.fc.in_features, num_classes)
    model.eval()  # inference behaviour for layers such as batch norm
    return model
# Image Preprocessing
def preprocess_image(image):
    """Convert a PIL image into a normalised (1, 3, 224, 224) float tensor.

    The image is resized to 224x224, converted to a tensor and normalised
    with the standard ImageNet per-channel mean/std, then given a leading
    batch dimension of size 1.

    Args:
        image: A PIL image in any mode (e.g. "L", "RGB", "RGBA").

    Returns:
        A tensor ready to feed to the classifiers.
    """
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    # X-rays are often single-channel ("L") or carry alpha ("RGBA"); the
    # three-channel normalisation above fails on those, so force RGB first.
    return transform(image.convert("RGB")).unsqueeze(0)
# Groq API for AI Insights
def get_ai_insights(text_prompt, model="llama-3.3-70b-versatile"):
    """Send a single-turn prompt to the Groq chat API and return the reply.

    Args:
        text_prompt: Text sent as the only (user) chat message.
        model: Groq model identifier. Defaults to "llama-3.3-70b-versatile".

    Returns:
        The content of the first returned choice, or a string starting with
        "Error: " if the request could not be made. Errors are returned
        rather than raised so the Streamlit page can display them inline.
    """
    if client is None:
        return "Error: GROQ_API_KEY is not configured."
    try:
        response = client.chat.completions.create(
            messages=[{"role": "user", "content": text_prompt}],
            model=model,
        )
        return response.choices[0].message.content
    except Exception as e:  # UI boundary: surface any API/network failure to the user
        return f"Error: {e}"
# Organ Recognition Prediction
def predict_organ(image):
    """Classify which body region an X-ray image shows.

    Args:
        image: PIL image to classify; it is prepared by ``preprocess_image``.

    Returns:
        One of "Lungs", "Heart", "Spine" or "Other".

    Raises:
        ValueError: If the model's output width does not match the label list.
    """
    classes = ["Lungs", "Heart", "Spine", "Other"]
    # Fetch the model via the cached loader so it is built once and reused
    # across Streamlit reruns.
    model = load_organ_model()
    with torch.no_grad():
        output = model(preprocess_image(image))
    if output.size(1) != len(classes):
        raise ValueError(
            f"Model output size ({output.size(1)}) does not match the number of classes ({len(classes)})."
        )
    # Argmax over the class dimension of the single-image batch.
    prediction_index = output.argmax(dim=1).item()
    return classes[prediction_index]
# Predict Normal/Abnormal
@st.cache_resource
def _load_chexnet_model():
    """Build a two-class (Normal/Abnormal) DenseNet-121 classifier.

    Fallback used by ``predict_normal_abnormal`` when no module-level
    ``chexnet_model`` has been defined. The backbone carries torchvision's
    ImageNet weights and the classifier head is replaced with a new 2-way
    linear layer; ``st.cache_resource`` builds it once per server process.

    NOTE(review): no chest-X-ray checkpoint is loaded here and the new head is
    randomly initialised, so outputs are not meaningful until trained weights
    are supplied — confirm before relying on them.
    """
    model = models.densenet121(weights=models.DenseNet121_Weights.DEFAULT)
    model.classifier = torch.nn.Linear(model.classifier.in_features, 2)
    model.eval()
    return model


def predict_normal_abnormal(image):
    """Classify an X-ray image as "Normal" or "Abnormal".

    Uses a module-level ``chexnet_model`` if one exists; otherwise falls back
    to the cached model returned by ``_load_chexnet_model``.

    Args:
        image: PIL image to classify; it is prepared by ``preprocess_image``.

    Returns:
        "Normal" or "Abnormal".

    Raises:
        ValueError: If the model's output width does not match the label list.
    """
    classes = ["Normal", "Abnormal"]
    model = globals().get("chexnet_model")
    if model is None:
        model = _load_chexnet_model()
    with torch.no_grad():
        output = model(preprocess_image(image))
    if output.size(1) != len(classes):
        raise ValueError(
            f"Model output size ({output.size(1)}) does not match the number of classes ({len(classes)})."
        )
    # Argmax over the class dimension of the single-image batch.
    return classes[output.argmax(dim=1).item()]
# Streamlit App
# Page layout: a sidebar chooses between analysing an uploaded X-ray and
# free-form questions to the language model.
st.title("Medical X-ray Analysis App")
st.sidebar.title("Navigation")
task = st.sidebar.radio("Select a task", ["Upload X-ray", "AI Insights"])
if task == "Upload X-ray":
    uploaded_file = st.file_uploader("Upload an X-ray image", type=["jpg", "png", "jpeg"])
    if uploaded_file:
        try:
            # Normalise to three channels up front: X-rays are often greyscale
            # ("L") or carry alpha ("RGBA"), while preprocessing applies a
            # three-channel normalisation.
            image = Image.open(uploaded_file).convert("RGB")
        except OSError as err:
            # PIL raises UnidentifiedImageError (an OSError subclass) for files
            # it cannot decode, e.g. a corrupt or mislabelled upload.
            st.error(f"Could not read the uploaded file as an image: {err}")
            st.stop()
        st.image(image, caption="Uploaded X-ray", use_column_width=True)
        # Predict Organ
        st.subheader("Step 1: Identify the Organ")
        organ = predict_organ(image)
        st.write(f"Predicted Organ: **{organ}**")
        # Predict Normal/Abnormal
        st.subheader("Step 2: Analyze the X-ray")
        classification = predict_normal_abnormal(image)
        st.write(f"X-ray Status: **{classification}**")
        if classification == "Abnormal":
            st.subheader("Step 3: AI-Based Insights")
            # Only the predicted organ name is sent; the language model never
            # receives the image, so its answer cannot be image-specific.
            ai_prompt = f"Explain why this X-ray of the {organ} is abnormal."
            insights = get_ai_insights(ai_prompt)
            st.write(insights)
elif task == "AI Insights":
    st.subheader("Ask AI")
    user_input = st.text_area("Enter your query for AI insights")
    if user_input:
        response = get_ai_insights(user_input)
        st.write(response)
|