abuzarAli commited on
Commit
99d610b
·
verified ·
1 Parent(s): f68cc3e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +61 -9
app.py CHANGED
@@ -3,19 +3,27 @@ import streamlit as st
3
  from PIL import Image
4
  import torch
5
  from torchvision import transforms, models
 
 
6
 
7
- # Set up environment variable for Groq API
8
  os.environ["GROQ_API_KEY"] = "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"
9
 
 
 
 
 
 
10
  # Load Pretrained Model for Organ Recognition
11
  @st.cache_resource
12
  def load_organ_model():
13
- model = models.resnet18(pretrained=True) # ResNet18 pretrained model
14
- model.fc = torch.nn.Linear(model.fc.in_features, 4) # Modify for 4 classes
15
- model.eval()
 
16
  return model
17
 
18
- organ_model = load_organ_model()
19
 
20
  # Image Preprocessing
21
  def preprocess_image(image):
@@ -26,17 +34,49 @@ def preprocess_image(image):
26
  ])
27
  return transform(image).unsqueeze(0)
28
 
 
 
 
 
 
 
 
 
 
 
 
29
  # Organ Recognition Prediction
30
  def predict_organ(image):
31
  with torch.no_grad():
32
  input_tensor = preprocess_image(image)
33
  output = organ_model(input_tensor)
34
- classes = ["Lungs", "Heart", "Spine", "Other"] # Example organ classes
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  prediction = classes[output.argmax().item()]
36
  return prediction
37
 
38
  # Streamlit App
39
- st.title("X-ray Organ Recognition App")
40
  st.sidebar.title("Navigation")
41
  task = st.sidebar.radio("Select a task", ["Upload X-ray", "AI Insights"])
42
 
@@ -48,12 +88,24 @@ if task == "Upload X-ray":
48
  st.image(image, caption="Uploaded X-ray", use_column_width=True)
49
 
50
  # Predict Organ
51
- st.subheader("Step 1: Identify the Organ in the X-ray")
52
  organ = predict_organ(image)
53
  st.write(f"Predicted Organ: **{organ}**")
54
 
 
 
 
 
 
 
 
 
 
 
 
55
  elif task == "AI Insights":
56
  st.subheader("Ask AI")
57
  user_input = st.text_area("Enter your query for AI insights")
58
  if user_input:
59
- st.write("AI insights will be generated here.") # Placeholder for AI logic
 
 
3
  from PIL import Image
4
  import torch
5
  from torchvision import transforms, models
6
+ import numpy as np
7
+ from groq import Groq
8
 
9
+ # Set up environment variables
10
  os.environ["GROQ_API_KEY"] = "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"
11
 
12
+ # Initialize Groq client
13
+ client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
14
+
15
+ # Load Pretrained Models
16
+ @st.cache_resource
17
  # Load Pretrained Model for Organ Recognition
18
  @st.cache_resource
19
  def load_organ_model():
20
+ model = models.resnet18(pretrained=True) # Load pretrained ResNet18
21
+ num_features = model.fc.in_features # Get the number of input features to the final layer
22
+ model.fc = torch.nn.Linear(num_features, 4) # Modify the final layer for 4 classes
23
+ model.eval() # Set the model to evaluation mode
24
  return model
25
 
26
+ organ_model, chexnet_model = load_model()
27
 
28
  # Image Preprocessing
29
  def preprocess_image(image):
 
34
  ])
35
  return transform(image).unsqueeze(0)
36
 
37
+ # Groq API for AI Insights
38
+ def get_ai_insights(text_prompt):
39
+ try:
40
+ response = client.chat.completions.create(
41
+ messages=[{"role": "user", "content": text_prompt}],
42
+ model="llama-3.3-70b-versatile"
43
+ )
44
+ return response.choices[0].message.content
45
+ except Exception as e:
46
+ return f"Error: {e}"
47
+
48
  # Organ Recognition Prediction
49
  def predict_organ(image):
50
  with torch.no_grad():
51
  input_tensor = preprocess_image(image)
52
  output = organ_model(input_tensor)
53
+
54
+ # Check the output dimensions
55
+ st.write(f"Model output shape: {output.shape}")
56
+
57
+ # Ensure the output matches the number of classes
58
+ classes = ["Lungs", "Heart", "Spine", "Other"]
59
+ if output.size(1) != len(classes):
60
+ raise ValueError(
61
+ f"Model output size ({output.size(1)}) does not match the number of classes ({len(classes)})."
62
+ )
63
+
64
+ # Get the prediction
65
+ prediction_index = output.argmax().item()
66
+ prediction = classes[prediction_index]
67
+ return prediction
68
+
69
+
70
+ # Predict Normal/Abnormal
71
+ def predict_normal_abnormal(image):
72
+ with torch.no_grad():
73
+ output = chexnet_model(preprocess_image(image))
74
+ classes = ["Normal", "Abnormal"]
75
  prediction = classes[output.argmax().item()]
76
  return prediction
77
 
78
  # Streamlit App
79
+ st.title("Medical X-ray Analysis App")
80
  st.sidebar.title("Navigation")
81
  task = st.sidebar.radio("Select a task", ["Upload X-ray", "AI Insights"])
82
 
 
88
  st.image(image, caption="Uploaded X-ray", use_column_width=True)
89
 
90
  # Predict Organ
91
+ st.subheader("Step 1: Identify the Organ")
92
  organ = predict_organ(image)
93
  st.write(f"Predicted Organ: **{organ}**")
94
 
95
+ # Predict Normal/Abnormal
96
+ st.subheader("Step 2: Analyze the X-ray")
97
+ classification = predict_normal_abnormal(image)
98
+ st.write(f"X-ray Status: **{classification}**")
99
+
100
+ if classification == "Abnormal":
101
+ st.subheader("Step 3: AI-Based Insights")
102
+ ai_prompt = f"Explain why this X-ray of the {organ} is abnormal."
103
+ insights = get_ai_insights(ai_prompt)
104
+ st.write(insights)
105
+
106
  elif task == "AI Insights":
107
  st.subheader("Ask AI")
108
  user_input = st.text_area("Enter your query for AI insights")
109
  if user_input:
110
+ response = get_ai_insights(user_input)
111
+ st.write(response)