aaravlovescodes commited on
Commit
fe99911
·
verified ·
1 Parent(s): 221b833

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +49 -58
app.py CHANGED
@@ -1,37 +1,53 @@
1
  import streamlit as st
2
  import tensorflow as tf
3
- import random
4
- from PIL import Image
5
- from tensorflow import keras
6
  import numpy as np
 
 
7
  import os
8
 
 
9
  import warnings
10
-
11
  warnings.filterwarnings("ignore")
12
- os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'
13
 
 
14
  st.set_page_config(
15
  page_title="ChestAI - Pneumonia Detection",
16
  page_icon="🫁",
17
  initial_sidebar_state="auto",
18
  )
19
 
 
20
  hide_streamlit_style = """
21
- <style>
22
- #MainMenu {visibility: hidden;}
23
- footer {visibility: hidden;}
24
- </style>
25
  """
26
  st.markdown(hide_streamlit_style, unsafe_allow_html=True)
27
 
 
 
 
 
 
 
28
 
29
- def prediction_cls(prediction):
30
- for key, clss in class_names.items():
31
- if np.argmax(prediction) == clss:
32
- return key
 
 
33
 
 
 
 
 
 
 
 
34
 
 
35
  with st.sidebar:
36
  st.title("ChestAI")
37
  st.markdown("""
@@ -46,68 +62,43 @@ with st.sidebar:
46
  ### Note
47
  This tool is for educational purposes only. Always consult healthcare professionals for medical advice.
48
  """)
49
- st.set_option("deprecation.showfileUploaderEncoding", False)
50
-
51
 
52
- @st.cache_resource(show_spinner=False)
53
- def load_model():
54
- try:
55
- from huggingface_hub import hf_hub_download
56
- from keras.layers import TFSMLayer
57
-
58
- # Download the model files directly
59
- model_path = hf_hub_download(repo_id="ryefoxlime/PneumoniaDetection", filename="saved_model.pb")
60
-
61
- # Use TFSMLayer to load the model
62
- model = TFSMLayer(model_path, call_endpoint='serving_default')
63
- return model
64
- except Exception as e:
65
- st.error(f"Error loading model: {str(e)}")
66
- return None
67
-
68
-
69
- with st.spinner("Model is being loaded..."):
70
- model = load_model()
71
-
72
- if model is None:
73
- st.error("Failed to load model. Please try again.")
74
- st.stop()
75
-
76
- file = st.file_uploader(" ", type=["jpg", "png"])
77
 
 
 
78
 
79
  def import_and_predict(image_data, model):
80
- img_array = keras.preprocessing.image.img_to_array(image_data)
81
- img_array = np.expand_dims(img_array, axis=0)
82
- img_array = img_array/255
83
 
84
- predictions = model.predict(img_array)
 
85
  return predictions
86
 
 
 
87
 
88
  if file is None:
89
  st.text("Please upload an image file")
90
  else:
91
  try:
92
- image = keras.preprocessing.image.load_img(file, target_size=(224, 224), color_mode='rgb')
93
  st.image(image, caption="Uploaded Image.", use_column_width=True)
94
- predictions = import_and_predict(image, model)
95
 
96
- class_names = [
97
- "Normal",
98
- "PNEUMONIA",
99
- ]
100
 
101
- confidence = float(max(predictions[0]) * 100)
102
- prediction_label = class_names[np.argmax(predictions)]
103
-
104
  st.info(f"Confidence: {confidence:.2f}%")
105
 
106
- if prediction_label == "Normal":
107
  st.balloons()
108
- st.success(f"Result: {prediction_label}")
109
  else:
110
- st.warning(f"Result: {prediction_label}")
111
-
112
  except Exception as e:
113
- st.error(f"Error processing image: {str(e)}")
 
1
  import streamlit as st
2
  import tensorflow as tf
 
 
 
3
  import numpy as np
4
+ from PIL import Image
5
+ from huggingface_hub import hf_hub_download
6
  import os
7
 
8
+ # Suppress warnings
9
  import warnings
 
10
  warnings.filterwarnings("ignore")
 
11
 
12
+ # Set page configuration
13
  st.set_page_config(
14
  page_title="ChestAI - Pneumonia Detection",
15
  page_icon="🫁",
16
  initial_sidebar_state="auto",
17
  )
18
 
19
+ # Hide Streamlit style
20
  hide_streamlit_style = """
21
+ <style>
22
+ #MainMenu {visibility: hidden;}
23
+ footer {visibility: hidden;}
24
+ </style>
25
  """
26
  st.markdown(hide_streamlit_style, unsafe_allow_html=True)
27
 
28
+ # Function to load the model
29
+ @st.cache_resource(show_spinner=False)
30
+ def load_model():
31
+ try:
32
+ # Download the model directory
33
+ model_dir = hf_hub_download(repo_id="ryefoxlime/PneumoniaDetection", repo_type="model", library="tf", cache_dir="/home/user/.cache/huggingface/hub")
34
 
35
+ # Load the model using tf.saved_model.load
36
+ model = tf.saved_model.load(model_dir)
37
+ return model
38
+ except Exception as e:
39
+ st.error(f"Error loading model: {str(e)}")
40
+ return None
41
 
42
+ # Load the model
43
+ with st.spinner("Model is being loaded..."):
44
+ model = load_model()
45
+
46
+ if model is None:
47
+ st.error("Failed to load model. Please try again.")
48
+ st.stop()
49
 
50
+ # Sidebar for app information
51
  with st.sidebar:
52
  st.title("ChestAI")
53
  st.markdown("""
 
62
  ### Note
63
  This tool is for educational purposes only. Always consult healthcare professionals for medical advice.
64
  """)
 
 
65
 
66
+ st.set_option("deprecation.showfileUploaderEncoding", False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
 
68
+ # File uploader for image input
69
+ file = st.file_uploader("Upload a chest X-ray image", type=["jpg", "png"])
70
 
71
  def import_and_predict(image_data, model):
72
+ img_array = tf.keras.preprocessing.image.img_to_array(image_data)
73
+ img_array = np.expand_dims(img_array, axis=0) # Add batch dimension
74
+ img_array = img_array / 255.0 # Normalize the image
75
 
76
+ # Perform prediction
77
+ predictions = model(img_array) # Call the model for prediction
78
  return predictions
79
 
80
+ # Class names for prediction results
81
+ class_names = ["Normal", "PNEUMONIA"]
82
 
83
  if file is None:
84
  st.text("Please upload an image file")
85
  else:
86
  try:
87
+ image = tf.keras.preprocessing.image.load_img(file, target_size=(224, 224), color_mode='rgb')
88
  st.image(image, caption="Uploaded Image.", use_column_width=True)
 
89
 
90
+ predictions = import_and_predict(image, model)
91
+ predicted_class = np.argmax(predictions) # Get the index of the highest prediction
92
+ confidence = float(predictions[0][predicted_class] * 100) # Confidence percentage
 
93
 
94
+ # Display the results
 
 
95
  st.info(f"Confidence: {confidence:.2f}%")
96
 
97
+ if class_names[predicted_class] == "Normal":
98
  st.balloons()
99
+ st.success(f"Result: {class_names[predicted_class]}")
100
  else:
101
+ st.warning(f"Result: {class_names[predicted_class]}")
102
+
103
  except Exception as e:
104
+ st.error(f"Error processing image: {str(e)}")