mgbam commited on
Commit
3300549
·
verified ·
1 Parent(s): 2d10123

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +118 -168
app.py CHANGED
@@ -1,168 +1,118 @@
1
- import os
2
- import traceback
3
- import numpy as np
4
- import streamlit as st
5
- from PIL import Image
6
- from transformers import pipeline
7
- import matplotlib.pyplot as plt
8
- from skimage.color import rgb2gray
9
- from skimage.filters import threshold_otsu
10
-
11
-
12
- # =======================
13
- # Configuration and Setup
14
- # =======================
15
-
16
- # Streamlit Page Configuration
17
- st.set_page_config(
18
- page_title="AI Cancer Detection Platform",
19
- page_icon="🩺",
20
- layout="wide",
21
- initial_sidebar_state="expanded",
22
- menu_items={
23
- "About": "### AI Cancer Detection Platform\n"
24
- "Developed to classify cancer images and provide research insights."
25
- }
26
- )
27
-
28
-
29
- # =======================
30
- # Helper Functions
31
- # =======================
32
-
33
- @st.cache_resource
34
- def load_pipeline():
35
- """
36
- Load the pre-trained image classification pipeline using PyTorch as the backend.
37
- """
38
- try:
39
- model_pipeline = pipeline(
40
- "image-classification",
41
- model="Anwarkh1/Skin_Cancer-Image_Classification",
42
- framework="pt" # Force PyTorch backend
43
- )
44
- return model_pipeline
45
- except Exception as e:
46
- st.error(f"Error loading model: {e}")
47
- traceback.print_exc()
48
- st.stop()
49
-
50
-
51
- def process_image(image):
52
- """
53
- Perform image processing to extract features for better visualization.
54
- """
55
- try:
56
- # Convert image to grayscale
57
- gray_image = rgb2gray(np.array(image))
58
-
59
- # Apply Otsu's threshold
60
- thresh = threshold_otsu(gray_image)
61
- binary = gray_image > thresh
62
-
63
- # Calculate edge pixel percentage
64
- edge_pixels = np.sum(binary)
65
- total_pixels = binary.size
66
- edge_percentage = (edge_pixels / total_pixels) * 100
67
-
68
- # Generate plots
69
- fig, ax = plt.subplots(1, 2, figsize=(10, 5))
70
- ax[0].imshow(gray_image, cmap="gray")
71
- ax[0].set_title("Grayscale Image")
72
- ax[0].axis("off")
73
-
74
- ax[1].imshow(binary, cmap="gray")
75
- ax[1].set_title("Binary Image (Thresholded)")
76
- ax[1].axis("off")
77
-
78
- plt.tight_layout()
79
- st.pyplot(fig)
80
-
81
- # Feature description
82
- return f"{edge_percentage:.2f}% of the image contains edge pixels after thresholding."
83
-
84
- except Exception as e:
85
- st.error(f"Error processing image: {e}")
86
- traceback.print_exc()
87
- return "No significant features extracted."
88
-
89
-
90
- def classify_image(image, model_pipeline):
91
- """
92
- Classify the uploaded image using the pre-trained model pipeline.
93
- """
94
- try:
95
- # Resize image to 224x224 as required by the model
96
- image_resized = image.resize((224, 224))
97
- predictions = model_pipeline(image_resized)
98
-
99
- if predictions:
100
- top_prediction = predictions[0]
101
- label = top_prediction["label"]
102
- score = top_prediction["score"]
103
- return label, score
104
- else:
105
- st.warning("No predictions were made.")
106
- return None, None
107
- except Exception as e:
108
- st.error(f"Error during classification: {e}")
109
- traceback.print_exc()
110
- return None, None
111
-
112
-
113
- # =======================
114
- # Streamlit Main Content
115
- # =======================
116
-
117
- st.title("🩺 AI-Powered Cancer Detection")
118
-
119
- # Image Upload Section
120
- st.subheader("📤 Upload a Cancer Image")
121
- uploaded_image = st.file_uploader("Choose an image file...", type=["png", "jpg", "jpeg"])
122
-
123
- if uploaded_image is not None:
124
- try:
125
- # Open the uploaded image
126
- image = Image.open(uploaded_image).convert("RGB")
127
-
128
- # Display the uploaded image
129
- st.image(image, caption="Uploaded Image", use_column_width=True)
130
-
131
- # Process the image
132
- st.markdown("### 🛠️ Image Processing")
133
- processed_features = process_image(image)
134
-
135
- # Load the model pipeline
136
- st.markdown("### 🔍 Classifying the Image")
137
- model_pipeline = load_pipeline()
138
-
139
- # Classify the image
140
- with st.spinner("Classifying..."):
141
- label, confidence = classify_image(image, model_pipeline)
142
-
143
- if label and confidence:
144
- st.write(f"**Prediction:** {label}")
145
- st.write(f"**Confidence:** {confidence:.2%}")
146
-
147
- # Highlight prediction confidence
148
- if confidence > 0.80:
149
- st.success("High confidence in the prediction.")
150
- elif confidence > 0.50:
151
- st.warning("Moderate confidence in the prediction.")
152
- else:
153
- st.error("Low confidence in the prediction.")
154
-
155
- except Exception as e:
156
- st.error(f"An unexpected error occurred: {e}")
157
- traceback.print_exc()
158
- else:
159
- st.info("Upload an image to start the classification.")
160
-
161
- # =======================
162
- # Footer
163
- # =======================
164
-
165
- st.markdown("""
166
- ---
167
- **AI Cancer Detection Platform** | This application is for informational purposes only and is not intended for medical diagnosis.
168
- """)
 
1
+ import streamlit as st
2
+ from transformers import pipeline, AutoModelForImageClassification, AutoFeatureExtractor
3
+ from PIL import Image
4
+ import openai
5
+ import os
6
+ import torch
7
+
8
+ # =======================
9
+ # Streamlit Page Config (MUST BE FIRST)
10
+ # =======================
11
+ st.set_page_config(
12
+ page_title="AI-Powered Skin Cancer Detection",
13
+ page_icon="🩺",
14
+ layout="wide",
15
+ initial_sidebar_state="expanded"
16
+ )
17
+
18
+ # =======================
19
+ # OpenAI API Configuration
20
+ # =======================
21
+ openai.api_key = os.getenv("OPENAI_API_KEY", "your_openai_api_key_here")
22
+
23
+ # =======================
24
+ # Load Model with PyTorch
25
+ # =======================
26
+ @st.cache_resource
27
+ def load_model():
28
+ """
29
+ Load the pre-trained skin cancer classification model using PyTorch.
30
+ Use the AutoModelForImageClassification and AutoFeatureExtractor for explicit local caching.
31
+ """
32
+ try:
33
+ extractor = AutoFeatureExtractor.from_pretrained("Anwarkh1/Skin_Cancer-Image_Classification")
34
+ model = AutoModelForImageClassification.from_pretrained("Anwarkh1/Skin_Cancer-Image_Classification")
35
+ return pipeline("image-classification", model=model, feature_extractor=extractor, framework="pt")
36
+ except Exception as e:
37
+ st.error(f"Error loading the model: {e}")
38
+ return None
39
+
40
+ model = load_model()
41
+
42
+ # =======================
43
+ # OpenAI Explanation Function
44
+ # =======================
45
+ def generate_openai_explanation(label, confidence):
46
+ """
47
+ Generate a detailed explanation for the classification result using OpenAI's GPT model.
48
+ """
49
+ prompt = (
50
+ f"The AI model has classified an image of a skin lesion as **{label}** with a confidence of **{confidence:.2%}**.\n"
51
+ f"Explain what this classification means, including potential characteristics of this lesion type, "
52
+ f"what steps a patient should take next, and how the AI might have arrived at this conclusion. "
53
+ f"Use language that is easy for a non-medical audience to understand."
54
+ )
55
+ try:
56
+ response = openai.Completion.create(
57
+ model="text-davinci-003", # Replace with "gpt-4" if available
58
+ prompt=prompt,
59
+ max_tokens=300,
60
+ temperature=0.7
61
+ )
62
+ return response.choices[0].text.strip()
63
+ except Exception as e:
64
+ return f"Error generating explanation: {e}"
65
+
66
+ # =======================
67
+ # Streamlit App Title and Sidebar
68
+ # =======================
69
+ st.title("🔍 AI-Powered Skin Cancer Classification and Explanation")
70
+ st.write("Upload an image of a skin lesion, and the AI model will classify it and provide a detailed explanation.")
71
+
72
+ st.sidebar.info("""
73
+ **AI Cancer Detection Platform**
74
+ This application uses AI to classify skin lesions and generate detailed explanations for informational purposes.
75
+ It is not intended for medical diagnosis. Always consult a healthcare professional for medical advice.
76
+ """)
77
+
78
+ # =======================
79
+ # File Upload and Prediction
80
+ # =======================
81
+ uploaded_image = st.file_uploader("Upload a skin lesion image (PNG, JPG, JPEG)", type=["png", "jpg", "jpeg"])
82
+
83
+ if uploaded_image:
84
+ # Display uploaded image
85
+ image = Image.open(uploaded_image).convert("RGB")
86
+ st.image(image, caption="Uploaded Image", use_column_width=True)
87
+
88
+ # Perform classification
89
+ if model is None:
90
+ st.error("Model could not be loaded. Please try again later.")
91
+ else:
92
+ with st.spinner("Classifying the image..."):
93
+ try:
94
+ results = model(image)
95
+ label = results[0]['label']
96
+ confidence = results[0]['score']
97
+
98
+ # Display prediction results
99
+ st.markdown(f"### Prediction: **{label}**")
100
+ st.markdown(f"### Confidence: **{confidence:.2%}**")
101
+
102
+ # Provide confidence-based insights
103
+ if confidence >= 0.8:
104
+ st.success("High confidence in the prediction.")
105
+ elif confidence >= 0.5:
106
+ st.warning("Moderate confidence in the prediction. Consider additional verification.")
107
+ else:
108
+ st.error("Low confidence in the prediction. Results should be interpreted with caution.")
109
+
110
+ # Generate explanation
111
+ with st.spinner("Generating a detailed explanation..."):
112
+ explanation = generate_openai_explanation(label, confidence)
113
+
114
+ st.markdown("### Explanation")
115
+ st.write(explanation)
116
+
117
+ except Exception as e:
118
+ st.error(f"Error during classification: {e}")