abuzarAli commited on
Commit
f68cc3e
·
verified ·
1 Parent(s): 45f426d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -56
app.py CHANGED
@@ -3,30 +3,19 @@ import streamlit as st
3
  from PIL import Image
4
  import torch
5
  from torchvision import transforms, models
6
- import numpy as np
7
- from groq import Groq
8
 
9
- # Set up environment variables
10
  os.environ["GROQ_API_KEY"] = "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"
11
 
12
- # Initialize Groq client
13
- client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
14
-
15
- # Load Pretrained Models
16
  @st.cache_resource
17
- def load_model():
18
- # Pretrained EfficientNet for organ recognition
19
- organ_model = models.efficientnet_b0(pretrained=True)
20
- organ_model.eval()
21
-
22
- # Pretrained DenseNet (CheXNet) for normal/abnormal classification
23
- chexnet_model = models.densenet121(pretrained=True)
24
- chexnet_model.classifier = torch.nn.Linear(1024, 2) # Normal, Abnormal
25
- chexnet_model.eval()
26
-
27
- return organ_model, chexnet_model
28
 
29
- organ_model, chexnet_model = load_model()
30
 
31
  # Image Preprocessing
32
  def preprocess_image(image):
@@ -37,35 +26,17 @@ def preprocess_image(image):
37
  ])
38
  return transform(image).unsqueeze(0)
39
 
40
- # Groq API for AI Insights
41
- def get_ai_insights(text_prompt):
42
- try:
43
- response = client.chat.completions.create(
44
- messages=[{"role": "user", "content": text_prompt}],
45
- model="llama-3.3-70b-versatile"
46
- )
47
- return response.choices[0].message.content
48
- except Exception as e:
49
- return f"Error: {e}"
50
-
51
- # Predict Organ
52
  def predict_organ(image):
53
  with torch.no_grad():
54
- output = organ_model(preprocess_image(image))
55
- classes = ["Lungs", "Heart", "Spine", "Other"] # Example classes
56
- prediction = classes[output.argmax().item()]
57
- return prediction
58
-
59
- # Predict Normal/Abnormal
60
- def predict_normal_abnormal(image):
61
- with torch.no_grad():
62
- output = chexnet_model(preprocess_image(image))
63
- classes = ["Normal", "Abnormal"]
64
  prediction = classes[output.argmax().item()]
65
  return prediction
66
 
67
  # Streamlit App
68
- st.title("Medical X-ray Analysis App")
69
  st.sidebar.title("Navigation")
70
  task = st.sidebar.radio("Select a task", ["Upload X-ray", "AI Insights"])
71
 
@@ -77,24 +48,12 @@ if task == "Upload X-ray":
77
  st.image(image, caption="Uploaded X-ray", use_column_width=True)
78
 
79
  # Predict Organ
80
- st.subheader("Step 1: Identify the Organ")
81
  organ = predict_organ(image)
82
  st.write(f"Predicted Organ: **{organ}**")
83
 
84
- # Predict Normal/Abnormal
85
- st.subheader("Step 2: Analyze the X-ray")
86
- classification = predict_normal_abnormal(image)
87
- st.write(f"X-ray Status: **{classification}**")
88
-
89
- if classification == "Abnormal":
90
- st.subheader("Step 3: AI-Based Insights")
91
- ai_prompt = f"Explain why this X-ray of the {organ} is abnormal."
92
- insights = get_ai_insights(ai_prompt)
93
- st.write(insights)
94
-
95
  elif task == "AI Insights":
96
  st.subheader("Ask AI")
97
  user_input = st.text_area("Enter your query for AI insights")
98
  if user_input:
99
- response = get_ai_insights(user_input)
100
- st.write(response)
 
3
  from PIL import Image
4
  import torch
5
  from torchvision import transforms, models
 
 
6
 
7
+ # Set up environment variable for Groq API
8
  os.environ["GROQ_API_KEY"] = "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"
9
 
10
+ # Load Pretrained Model for Organ Recognition
 
 
 
11
  @st.cache_resource
12
+ def load_organ_model():
13
+ model = models.resnet18(pretrained=True) # ResNet18 pretrained model
14
+ model.fc = torch.nn.Linear(model.fc.in_features, 4) # Modify for 4 classes
15
+ model.eval()
16
+ return model
 
 
 
 
 
 
17
 
18
+ organ_model = load_organ_model()
19
 
20
  # Image Preprocessing
21
  def preprocess_image(image):
 
26
  ])
27
  return transform(image).unsqueeze(0)
28
 
29
+ # Organ Recognition Prediction
 
 
 
 
 
 
 
 
 
 
 
30
  def predict_organ(image):
31
  with torch.no_grad():
32
+ input_tensor = preprocess_image(image)
33
+ output = organ_model(input_tensor)
34
+ classes = ["Lungs", "Heart", "Spine", "Other"] # Example organ classes
 
 
 
 
 
 
 
35
  prediction = classes[output.argmax().item()]
36
  return prediction
37
 
38
  # Streamlit App
39
+ st.title("X-ray Organ Recognition App")
40
  st.sidebar.title("Navigation")
41
  task = st.sidebar.radio("Select a task", ["Upload X-ray", "AI Insights"])
42
 
 
48
  st.image(image, caption="Uploaded X-ray", use_column_width=True)
49
 
50
  # Predict Organ
51
+ st.subheader("Step 1: Identify the Organ in the X-ray")
52
  organ = predict_organ(image)
53
  st.write(f"Predicted Organ: **{organ}**")
54
 
 
 
 
 
 
 
 
 
 
 
 
55
  elif task == "AI Insights":
56
  st.subheader("Ask AI")
57
  user_input = st.text_area("Enter your query for AI insights")
58
  if user_input:
59
+ st.write("AI insights will be generated here.") # Placeholder for AI logic