abuzarAli committed on
Commit cf8fb5c · verified · 1 Parent(s): 78f0c67

Update app.py

Files changed (1)
  1. app.py +64 -101
app.py CHANGED
@@ -1,111 +1,74 @@
  import os
  import streamlit as st
  from PIL import Image
- import torch
- from torchvision import transforms, models
- import numpy as np
- from groq import Groq
-
- # Set up environment variables
- os.environ["GROQ_API_KEY"] = "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"
-
- # Initialize Groq client
- client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
-
- # Load Pretrained Models
- @st.cache_resource
- # Load Pretrained Model for Organ Recognition
- @st.cache_resource
- def load_organ_model():
-     model = models.resnet18(pretrained=True)  # Load pretrained ResNet18
-     num_features = model.fc.in_features  # Get the number of input features to the final layer
-     model.fc = torch.nn.Linear(num_features, 4)  # Modify the final layer for 4 classes
-     model.eval()  # Set the model to evaluation mode
-     return model
-
- # Image Preprocessing
- def preprocess_image(image):
-     transform = transforms.Compose([
-         transforms.Resize((224, 224)),
-         transforms.ToTensor(),
-         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
-     ])
-     return transform(image).unsqueeze(0)
-
- # Groq API for AI Insights
- def get_ai_insights(text_prompt):
      try:
-         response = client.chat.completions.create(
-             messages=[{"role": "user", "content": text_prompt}],
-             model="llama-3.3-70b-versatile"
-         )
-         return response.choices[0].message.content
      except Exception as e:
-         return f"Error: {e}"
-
- # Organ Recognition Prediction
- def predict_organ(image):
-     with torch.no_grad():
-         input_tensor = preprocess_image(image)
-         output = organ_model(input_tensor)
-
-         # Check the output dimensions
-         st.write(f"Model output shape: {output.shape}")
-
-         # Ensure the output matches the number of classes
-         classes = ["Lungs", "Heart", "Spine", "Other"]
-         if output.size(1) != len(classes):
-             raise ValueError(
-                 f"Model output size ({output.size(1)}) does not match the number of classes ({len(classes)})."
-             )
-
-         # Get the prediction
-         prediction_index = output.argmax().item()
-         prediction = classes[prediction_index]
-         return prediction

-
- # Predict Normal/Abnormal
- def predict_normal_abnormal(image):
-     with torch.no_grad():
-         output = chexnet_model(preprocess_image(image))
-         classes = ["Normal", "Abnormal"]
-         prediction = classes[output.argmax().item()]
-         return prediction

  # Streamlit App
- st.title("Medical X-ray Analysis App")
- st.sidebar.title("Navigation")
- task = st.sidebar.radio("Select a task", ["Upload X-ray", "AI Insights"])
-
- if task == "Upload X-ray":
-     uploaded_file = st.file_uploader("Upload an X-ray image", type=["jpg", "png", "jpeg"])
-
-     if uploaded_file:
-         image = Image.open(uploaded_file)
-         st.image(image, caption="Uploaded X-ray", use_column_width=True)
-
-         # Predict Organ
-         st.subheader("Step 1: Identify the Organ")
-         organ = predict_organ(image)
-         st.write(f"Predicted Organ: **{organ}**")
-
-         # Predict Normal/Abnormal
-         st.subheader("Step 2: Analyze the X-ray")
-         classification = predict_normal_abnormal(image)
-         st.write(f"X-ray Status: **{classification}**")
-
-         if classification == "Abnormal":
-             st.subheader("Step 3: AI-Based Insights")
-             ai_prompt = f"Explain why this X-ray of the {organ} is abnormal."
-             insights = get_ai_insights(ai_prompt)
-             st.write(insights)

- elif task == "AI Insights":
-     st.subheader("Ask AI")
-     user_input = st.text_area("Enter your query for AI insights")
-     if user_input:
-         response = get_ai_insights(user_input)
-         st.write(response)
 
  import os
  import streamlit as st
+ from transformers import pipeline, AutoImageProcessor, AutoModelForImageClassification
  from PIL import Image

+ def load_pipeline():
+     """Load the Hugging Face pipeline for image classification."""
      try:
+         return pipeline("image-classification", model="dima806/pneumonia_chest_xray_image_detection")
      except Exception as e:
+         st.error(f"Error loading pipeline: {e}")
+         return None

+ def classify_image_with_pipeline(pipe, image):
+     """Classify an image using the pipeline."""
+     try:
+         results = pipe(image)
+         return results
+     except Exception as e:
+         st.error(f"Error classifying image: {e}")
+         return None

  # Streamlit App
+ st.title("Pneumonia Chest X-ray Image Detection")
+ st.markdown(
+     """
+     This app detects signs of pneumonia in chest X-ray images using a pre-trained Hugging Face model.
+     """
+ )
+
+ # File uploader
+ uploaded_file = st.file_uploader("Upload a chest X-ray image", type=["jpg", "jpeg", "png"])
+
+ if uploaded_file:
+     image = Image.open(uploaded_file)
+     st.image(image, caption="Uploaded Chest X-ray", use_column_width=True)
+
+     # Load the model pipeline
+     pipe = load_pipeline()
+
+     if pipe:
+         st.write("Classifying the image...")
+         results = classify_image_with_pipeline(pipe, image)
+
+         if results:
+             st.write("### Classification Results:")
+             for result in results:
+                 st.write(f"**Label:** {result['label']} | **Score:** {result['score']:.4f}")
+
+ # Optional: Add Groq API integration if applicable
+ if os.getenv("GROQ_API_KEY"):
+     from groq import Groq
+
+     client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
+
+     st.sidebar.markdown("### Groq API Integration")
+     question = st.sidebar.text_input("Ask a question about pneumonia or X-ray diagnosis:")
+
+     if question:
+         try:
+             chat_completion = client.chat.completions.create(
+                 messages=[
+                     {
+                         "role": "user",
+                         "content": question,
+                     }
+                 ],
+                 model="llama-3.3-70b-versatile",
+             )

+             st.sidebar.write("**Groq API Response:**")
+             st.sidebar.write(chat_completion.choices[0].message.content)
+         except Exception as e:
+             st.sidebar.error(f"Error using Groq API: {e}")
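
For reference, the classification path added in this commit can be exercised outside Streamlit by calling the pipeline directly. This is a minimal sketch, not part of the commit: it assumes transformers and torch are installed, "sample_xray.jpg" is a hypothetical local test image, and the model id is the one referenced in app.py. An image-classification pipeline returns a list of {"label": ..., "score": ...} dicts, which is exactly the shape the results loop above iterates over.

    from PIL import Image
    from transformers import pipeline

    # Load the same checkpoint the app uses (downloads on first run).
    pipe = pipeline("image-classification", model="dima806/pneumonia_chest_xray_image_detection")

    # "sample_xray.jpg" is a hypothetical test file, not shipped with this repo.
    image = Image.open("sample_xray.jpg")

    for result in pipe(image):
        # Each entry is a dict such as {"label": "PNEUMONIA", "score": 0.97}.
        print(f"{result['label']}: {result['score']:.4f}")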