File size: 3,262 Bytes
d167cec
 
 
fe99911
 
fa9e49b
d167cec
fe99911
d167cec
 
 
fe99911
d167cec
6d6b205
 
 
 
d167cec
fe99911
d167cec
fe99911
 
 
 
d167cec
 
 
fe99911
 
 
 
 
 
281deee
fe99911
 
 
 
 
 
281deee
fe99911
 
 
 
 
 
 
d167cec
fe99911
d167cec
6d6b205
ec84b54
6d6b205
 
65847f9
6d6b205
 
 
 
65847f9
6d6b205
 
 
281deee
fe99911
d167cec
fe99911
 
281deee
d167cec
fe99911
 
 
d167cec
fe99911
 
d167cec
 
fe99911
 
281deee
d167cec
 
 
3481d08
fe99911
3481d08
281deee
fe99911
 
 
3481d08
fe99911
3481d08
281deee
fe99911
3481d08
fe99911
3481d08
fe99911
 
3481d08
fe99911
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import streamlit as st
import tensorflow as tf
import numpy as np
from PIL import Image
from huggingface_hub import hf_hub_download
import os

import warnings

# Silence every Python warning raised from this point on.
warnings.filterwarnings("ignore")

# Browser-tab title/icon and default sidebar behaviour for the app.
st.set_page_config(
    page_title="ChestAI - Pneumonia Detection",
    page_icon="🫁",
    initial_sidebar_state="auto",
)

# CSS injected as raw HTML to hide Streamlit's hamburger menu and footer.
hide_streamlit_style = "\n".join(
    [
        "",
        "<style>",
        "#MainMenu {visibility: hidden;}",
        "footer {visibility: hidden;}",
        "</style>",
        "",
    ]
)
st.markdown(hide_streamlit_style, unsafe_allow_html=True)

@st.cache_resource(show_spinner=False)
def load_model():
    """Download the PneumoniaDetection SavedModel from the Hugging Face Hub and load it.

    The result is cached by ``st.cache_resource``. Note that a ``None``
    result (failure) is cached as well; callers that want a retry should
    call ``load_model.clear()`` first.

    Returns:
        The object returned by ``tf.saved_model.load``, or ``None`` if the
        download or load failed (an error is shown in the app in that case).
    """
    try:
        # A SavedModel is a directory, so fetch the whole repository.
        # hf_hub_download only fetches one named file: it requires a
        # `filename` argument and has no `library` keyword, so calling it as
        # before always raised TypeError and this function returned None.
        from huggingface_hub import snapshot_download

        model_dir = snapshot_download(
            repo_id="ryefoxlime/PneumoniaDetection",
            repo_type="model",
            library_name="tf",
            cache_dir="/home/user/.cache/huggingface/hub",
        )

        # NOTE(review): assumes saved_model.pb + variables/ sit at the repo
        # root -- TODO confirm against the Hub repository layout.
        return tf.saved_model.load(model_dir)
    except Exception as e:
        st.error(f"Error loading model: {str(e)}")
        return None

# Load (or fetch from the st.cache_resource cache) the model, with a spinner.
with st.spinner("Model is being loaded..."):
    model = load_model()

if model is None:
    # load_model returns None on failure, and st.cache_resource caches that
    # None like any other result. Without clearing it, every rerun would
    # reuse the failed result and "try again" could never succeed.
    load_model.clear()
    st.error("Failed to load model. Please try again.")
    st.stop()

# Static "About / How to use / Note" panel, written directly to the sidebar container.
st.sidebar.title("ChestAI")
st.sidebar.markdown("""
        ### About
        ChestAI uses advanced deep learning to detect pneumonia in chest X-rays.
                
        ### How to use
        1. Upload a chest X-ray image (JPG/PNG)
        2. Wait for the analysis
        3. View the results and confidence score
                
        ### Note
        This tool is for educational purposes only. Always consult healthcare professionals for medical advice.
    """)

# NOTE: the former `deprecation.showfileUploaderEncoding` option no longer
# exists; setting it makes current Streamlit raise StreamlitAPIException
# ("Unrecognized config option"), so it is intentionally not set.

# File uploader for the image to classify; `file` is None until one is chosen.
file = st.file_uploader("Upload a chest X-ray image", type=["jpg", "png"])

def import_and_predict(image_data, model):
    """Run ``model`` on one image and return its raw output.

    The image is converted with Keras ``img_to_array``, divided by 255.0,
    and given a leading batch axis of size 1 before being passed to
    ``model``. The model's return value is passed back unchanged.

    Args:
        image_data: Image accepted by ``img_to_array`` (the caller passes a
            PIL image from ``load_img``).
        model: Callable model, e.g. the object from ``tf.saved_model.load``.
    """
    scaled = tf.keras.preprocessing.image.img_to_array(image_data) / 255.0
    batch = np.expand_dims(scaled, 0)
    return model(batch)

# Display labels, indexed by the position of the model's highest score.
# NOTE(review): assumes the model emits a (1, 2) score tensor ordered
# [Normal, PNEUMONIA]; a single-unit (sigmoid) output would always argmax to
# index 0 -- TODO confirm against the trained model.
class_names = ["Normal", "PNEUMONIA"]

if file is not None:
    try:
        xray = tf.keras.preprocessing.image.load_img(
            file, target_size=(224, 224), color_mode='rgb'
        )
        st.image(xray, caption="Uploaded Image.", use_column_width=True)

        scores = import_and_predict(xray, model)
        best = np.argmax(scores)
        # Score of the winning class, scaled to a percentage for display.
        confidence = float(scores[0][best] * 100)

        st.info(f"Confidence: {confidence:.2f}%")

        label = class_names[best]
        if label != "Normal":
            st.warning(f"Result: {label}")
        else:
            st.balloons()
            st.success(f"Result: {label}")

    except Exception as e:
        st.error(f"Error processing image: {str(e)}")
else:
    st.text("Please upload an image file")