mgbam's picture
Update app.py
5b2305a verified
raw
history blame
4.9 kB
import streamlit as st
from transformers import pipeline, AutoModelForImageClassification, AutoFeatureExtractor
from PIL import Image
import openai
import os
from dotenv import load_dotenv
# =======================
# Load Environment Variables from .env File
# =======================
load_dotenv()  # pull KEY=value pairs from a local .env file into os.environ
# Store the key on the openai module so the module-level API calls pick it up.
_api_key = os.getenv("OPENAI_API_KEY")
openai.api_key = _api_key
# Refuse to render the rest of the app unless the key is present and has the
# "sk-" prefix used by OpenAI secret keys.
_key_looks_valid = bool(_api_key) and _api_key.startswith("sk-")
if not _key_looks_valid:
    st.error("OpenAI API key is not set or is invalid. Please check the `.env` file or your environment variable setup.")
    st.stop()
# =======================
# Streamlit Page Config
# =======================
# Browser-tab title/icon, wide layout, and a sidebar that starts open.
_PAGE_SETTINGS = {
    "page_title": "AI-Powered Skin Cancer Detection",
    "page_icon": "🩺",
    "layout": "wide",
    "initial_sidebar_state": "expanded",
}
st.set_page_config(**_PAGE_SETTINGS)
# =======================
# Load Skin Cancer Model (PyTorch)
# =======================
# Hugging Face Hub id of the skin-lesion image classifier.
SKIN_CANCER_MODEL_ID = "Anwarkh1/Skin_Cancer-Image_Classification"


@st.cache_resource
def _build_classifier(model_id: str):
    """Build a PyTorch image-classification pipeline for *model_id*.

    Exceptions are deliberately allowed to propagate: ``st.cache_resource``
    only stores a value when the function returns normally, so a failed
    download/load is retried on the next script run instead of being cached.
    """
    extractor = AutoFeatureExtractor.from_pretrained(model_id)
    classifier = AutoModelForImageClassification.from_pretrained(model_id)
    return pipeline("image-classification", model=classifier, feature_extractor=extractor, framework="pt")


def load_model(model_id: str = SKIN_CANCER_MODEL_ID):
    """
    Load the pre-trained skin cancer classification pipeline (PyTorch).

    The heavy construction is cached per ``model_id`` by ``_build_classifier``.
    Errors are reported in the UI and turned into ``None`` here, outside the
    cache, so a transient failure (e.g. network outage while downloading)
    does not permanently disable the model for the whole server process.

    Args:
        model_id: Hugging Face Hub model id; defaults to the skin-lesion model.

    Returns:
        A transformers ``image-classification`` pipeline, or ``None`` if loading failed.
    """
    try:
        return _build_classifier(model_id)
    except Exception as e:
        st.error(f"Error loading the model: {e}")
        return None


model = load_model()
# =======================
# Generate OpenAI Explanation
# =======================
def generate_openai_explanation(label, confidence, openai_model="gpt-4o-mini"):
    """
    Ask an OpenAI chat model to explain a classification result in plain language.

    The original implementation used the ``text-davinci-003`` completions model,
    which OpenAI has retired (and ``openai.Completion`` no longer exists in
    openai>=1.0), so every call failed. This version uses the chat completions
    endpoint and supports both the v1 and the pre-v1 ``openai`` package.

    Args:
        label: Class name predicted by the image classifier.
        confidence: Classifier score; formatted as a percentage in the prompt.
        openai_model: OpenAI chat model id to query.

    Returns:
        The explanation text, or a string starting with
        ``"Error generating explanation:"`` if the request fails. Errors are
        returned rather than raised so the UI can still show the prediction.
    """
    prompt = (
        f"The AI model has classified an image of a skin lesion as **{label}** with a confidence of **{confidence:.2%}**.\n"
        "Explain what this classification means, including potential characteristics of this lesion type, "
        "what steps a patient should take next, and how the AI might have arrived at this conclusion. "
        "Use language that is easy for a non-medical audience to understand."
    )
    messages = [{"role": "user", "content": prompt}]
    try:
        if hasattr(openai, "OpenAI"):
            # openai>=1.0: use the module-level client, which reads openai.api_key.
            response = openai.chat.completions.create(
                model=openai_model,
                messages=messages,
                max_tokens=300,
                temperature=0.7,
            )
        else:
            # openai<1.0 exposes the chat endpoint as ChatCompletion.
            response = openai.ChatCompletion.create(
                model=openai_model,
                messages=messages,
                max_tokens=300,
                temperature=0.7,
            )
        # The message content can be None (e.g. filtered output); never call .strip() on None.
        content = response.choices[0].message.content
        return (content or "").strip()
    except Exception as e:
        return f"Error generating explanation: {e}"
# =======================
# Streamlit App Title and Sidebar
# =======================
# The title emoji was stored mis-decoded ("πŸ”"); restored to the intended 🔍.
st.title("🔍 AI-Powered Skin Cancer Classification and Explanation")
st.write("Upload an image of a skin lesion, and the AI model will classify it and provide a detailed explanation.")
# Sidebar disclaimer shown on every run: results are informational only.
st.sidebar.info("""
**AI Cancer Detection Platform**
This application uses AI to classify skin lesions and generate detailed explanations for informational purposes.
It is not intended for medical diagnosis. Always consult a healthcare professional for medical advice.
""")
# =======================
# File Upload and Prediction
# =======================
uploaded_image = st.file_uploader("Upload a skin lesion image (PNG, JPG, JPEG)", type=["png", "jpg", "jpeg"])
if uploaded_image:
    # Decode the upload before anything else. A corrupt or mislabelled file makes
    # Pillow raise UnidentifiedImageError (an OSError subclass), and an oversized
    # one raises DecompressionBombError; report either instead of a raw traceback.
    try:
        image = Image.open(uploaded_image).convert("RGB")
    except (OSError, Image.DecompressionBombError) as e:
        image = None
        st.error(f"Could not read the uploaded file as an image: {e}")

    if image is not None:
        # Display uploaded image
        st.image(image, caption="Uploaded Image", use_column_width=True)

        if model is None:
            st.error("Model could not be loaded. Please try again later.")
        else:
            # Guard only the inference call, so display code is never
            # misreported as a classification failure.
            results = None
            with st.spinner("Classifying the image..."):
                try:
                    results = model(image)
                except Exception as e:
                    st.error(f"Error during classification: {e}")

            if results is not None and not results:
                # Avoid a cryptic IndexError when the pipeline returns no labels.
                st.warning("The model returned no predictions for this image.")
            elif results:
                # The first entry is treated as the top prediction.
                top = results[0]
                label = top["label"]
                confidence = top["score"]

                # Display prediction results
                st.markdown(f"### Prediction: **{label}**")
                st.markdown(f"### Confidence: **{confidence:.2%}**")

                # Confidence-based guidance: >=80% high, >=50% moderate, else low.
                if confidence >= 0.8:
                    st.success("High confidence in the prediction.")
                elif confidence >= 0.5:
                    st.warning("Moderate confidence in the prediction. Consider additional verification.")
                else:
                    st.error("Low confidence in the prediction. Results should be interpreted with caution.")

                # generate_openai_explanation reports its own failures as text.
                with st.spinner("Generating a detailed explanation..."):
                    explanation = generate_openai_explanation(label, confidence)
                st.markdown("### Explanation")
                st.write(explanation)