abuzarAli committed on
Commit 4c0ed57 · verified · 1 Parent(s): 4265985

Update app.py

Files changed (1)
  1. app.py +88 -72
app.py CHANGED
@@ -1,84 +1,100 @@
  import os
  import streamlit as st
- import tensorflow as tf
- from tensorflow.keras.applications import ResNet50
- from tensorflow.keras.applications.resnet50 import preprocess_input, decode_predictions
- from tensorflow.keras.models import Model
- from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
- from tensorflow.keras.preprocessing.image import img_to_array, load_img
- import numpy as np
- import requests
  from PIL import Image

- # Set Groq API key in environment variable
- os.environ['GROQ_API_KEY'] = "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"
- GROQ_API_KEY = os.getenv('GROQ_API_KEY')
-
- # Load pre-trained ResNet50 for normal/abnormal classification
- base_model = ResNet50(weights='imagenet', include_top=False, input_shape=(224, 224, 3))
- x = base_model.output
- x = GlobalAveragePooling2D()(x)
- x = Dense(1024, activation='relu')(x)
- predictions = Dense(1, activation='sigmoid')(x)
- classification_model = Model(inputs=base_model.input, outputs=predictions)
-
- # Load pre-trained ResNet50 for organ recognition
- organ_model = ResNet50(weights='imagenet')
-
- def classify_image(image_path):
-     """Classify the image as normal or abnormal."""
-     image = load_img(image_path, target_size=(224, 224))
-     image_array = img_to_array(image)
-     image_array = preprocess_input(image_array)
-     image_array = np.expand_dims(image_array, axis=0)
-     prediction = classification_model.predict(image_array)
-     return 'Abnormal' if prediction[0][0] > 0.5 else 'Normal'
-
- def recognize_organ(image_path):
-     """Recognize the organ in the image."""
-     image = load_img(image_path, target_size=(224, 224))
-     image_array = img_to_array(image)
-     image_array = preprocess_input(image_array)
-     image_array = np.expand_dims(image_array, axis=0)
-     prediction = organ_model.predict(image_array)
-     decoded = decode_predictions(prediction, top=3)[0]
-     return decoded[0][1]  # Top predicted class
-
- def get_ai_insights(organ):
-     """Fetch AI-based insights about the organ using Groq API."""
-     url = "https://api.groq.com/v1/insights"
-     headers = {"Authorization": f"Bearer {GROQ_API_KEY}", "Content-Type": "application/json"}
-     data = {"query": f"Provide detailed insights about {organ} X-ray, its diseases, and treatments."}
-     response = requests.post(url, headers=headers, json=data)
-     if response.status_code == 200:
-         return response.json().get("insights", "No insights available.")
-     else:
-         return "Failed to fetch insights. Please try again later."
-
- def main():
-     st.title("Medical Image Classification App")
-     st.sidebar.title("Navigation")

-     uploaded_file = st.file_uploader("Upload an X-ray or MRI image", type=["jpg", "jpeg", "png"])

      if uploaded_file:
          image = Image.open(uploaded_file)
-         st.image(image, caption="Uploaded Image", use_column_width=True)
-
-         with open("temp_image.jpg", "wb") as f:
-             f.write(uploaded_file.getbuffer())

-         st.write("### Classification Result")
-         result = classify_image("temp_image.jpg")
-         st.write(f"The X-ray is classified as: **{result}**")

-         st.write("### Organ Recognition")
-         organ = recognize_organ("temp_image.jpg")
-         st.write(f"Recognized Organ: **{organ}**")

-         st.write("### AI-Based Insights")
-         insights = get_ai_insights(organ)
-         st.write(insights)

- if __name__ == "__main__":
-     main()
  import os
  import streamlit as st
  from PIL import Image
+ import torch
+ from torchvision import transforms, models
+ import numpy as np
+ from groq import Groq
+
+ # Set up environment variables
+ os.environ["GROQ_API_KEY"] = "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"

+ # Initialize Groq client
+ client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
+
+ # Load Pretrained Models
+ @st.cache_resource
+ def load_model():
+     # Pretrained EfficientNet for organ recognition
+     organ_model = models.efficientnet_b0(pretrained=True)
+     organ_model.eval()

+     # Pretrained DenseNet (CheXNet) for normal/abnormal classification
+     chexnet_model = models.densenet121(pretrained=True)
+     chexnet_model.classifier = torch.nn.Linear(1024, 2)  # Normal, Abnormal
+     chexnet_model.eval()
+
+     return organ_model, chexnet_model
+
+ organ_model, chexnet_model = load_model()
+
+ # Image Preprocessing
+ def preprocess_image(image):
+     transform = transforms.Compose([
+         transforms.Resize((224, 224)),
+         transforms.ToTensor(),
+         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
+     ])
+     return transform(image).unsqueeze(0)
+
+ # Groq API for AI Insights
+ def get_ai_insights(text_prompt):
+     try:
+         response = client.chat.completions.create(
+             messages=[{"role": "user", "content": text_prompt}],
+             model="llama-3.3-70b-versatile"
+         )
+         return response.choices[0].message.content
+     except Exception as e:
+         return f"Error: {e}"
+
+ # Predict Organ
+ def predict_organ(image):
+     with torch.no_grad():
+         output = organ_model(preprocess_image(image))
+     classes = ["Lungs", "Heart", "Spine", "Other"]  # Example classes
+     prediction = classes[output.argmax().item()]
+     return prediction
+
+ # Predict Normal/Abnormal
+ def predict_normal_abnormal(image):
+     with torch.no_grad():
+         output = chexnet_model(preprocess_image(image))
+     classes = ["Normal", "Abnormal"]
+     prediction = classes[output.argmax().item()]
+     return prediction
+
+ # Streamlit App
+ st.title("Medical X-ray Analysis App")
+ st.sidebar.title("Navigation")
+ task = st.sidebar.radio("Select a task", ["Upload X-ray", "AI Insights"])
+
+ if task == "Upload X-ray":
+     uploaded_file = st.file_uploader("Upload an X-ray image", type=["jpg", "png", "jpeg"])

      if uploaded_file:
          image = Image.open(uploaded_file)
+         st.image(image, caption="Uploaded X-ray", use_column_width=True)

+         # Predict Organ
+         st.subheader("Step 1: Identify the Organ")
+         organ = predict_organ(image)
+         st.write(f"Predicted Organ: **{organ}**")

+         # Predict Normal/Abnormal
+         st.subheader("Step 2: Analyze the X-ray")
+         classification = predict_normal_abnormal(image)
+         st.write(f"X-ray Status: **{classification}**")

+         if classification == "Abnormal":
+             st.subheader("Step 3: AI-Based Insights")
+             ai_prompt = f"Explain why this X-ray of the {organ} is abnormal."
+             insights = get_ai_insights(ai_prompt)
+             st.write(insights)

+ elif task == "AI Insights":
+     st.subheader("Ask AI")
+     user_input = st.text_area("Enter your query for AI insights")
+     if user_input:
+         response = get_ai_insights(user_input)
+         st.write(response)
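
A caveat on the added code: models.efficientnet_b0(pretrained=True) keeps its 1000-way ImageNet classifier, so output.argmax().item() in predict_organ can return an index well beyond the four-entry classes list and raise an IndexError, and even in-range indices would correspond to ImageNet labels rather than organs. The ImageNet-style transforms.Normalize also assumes a 3-channel input, while many X-ray files open as single-channel images. The sketch below shows one way the head and the transform could be adapted; the 4-class layer and the ORGAN_CLASSES name are illustrative assumptions, not part of the commit, and the replacement head would still need fine-tuning on labeled X-ray data before its predictions are meaningful.

# Sketch only (assumptions: 4 illustrative organ labels, RGB conversion for grayscale X-rays).
import torch
from torchvision import transforms, models
from PIL import Image

ORGAN_CLASSES = ["Lungs", "Heart", "Spine", "Other"]  # hypothetical labels, as in the app

organ_model = models.efficientnet_b0(pretrained=True)
in_features = organ_model.classifier[1].in_features  # 1280 for EfficientNet-B0
# Replace the 1000-way ImageNet head so argmax always indexes ORGAN_CLASSES;
# this layer starts out untrained and must be fine-tuned on labeled X-ray data.
organ_model.classifier[1] = torch.nn.Linear(in_features, len(ORGAN_CLASSES))
organ_model.eval()

preprocess = transforms.Compose([
    transforms.Lambda(lambda img: img.convert("RGB")),  # grayscale X-ray -> 3 channels
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

def predict_organ(image: Image.Image) -> str:
    with torch.no_grad():
        logits = organ_model(preprocess(image).unsqueeze(0))
    return ORGAN_CLASSES[logits.argmax().item()]

The predict_normal_abnormal path does not have the indexing problem, since its replacement torch.nn.Linear(1024, 2) head already matches the two labels, but that layer is likewise randomly initialized and untrained in this commit.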