Plsek commited on
Commit
0755e66
·
1 Parent(s): 85b6907

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -52
app.py CHANGED
@@ -24,8 +24,6 @@ from sklearn.cluster import DBSCAN
24
  # Streamlit
25
  import streamlit as st
26
  st.set_option('deprecation.showPyplotGlobalUse', False)
27
- st.set_page_config(page_title="Cavity Detection Tool", layout="wide")
28
- # st.title("Cavity Detection Tool")
29
 
30
  # Define function to plot the uploaded image
31
  def plot_image(image, scale):
@@ -58,7 +56,7 @@ def plot_decomposed(decomposed):
58
  plt.axis('off')
59
  with colC: st.pyplot()
60
 
61
- # Cut input image and rebin it to 128x128 pixels
62
  def cut(data0, wcs0, scale=1):
63
  shape = data0.shape[0]
64
  x0 = shape / 2
@@ -81,7 +79,7 @@ def cut(data0, wcs0, scale=1):
81
 
82
  return data, wcs
83
 
84
-
85
  @st.cache
86
  def cut_n_predict(data, wcs, scale):
87
  data, wcs = cut(data, wcs, scale=scale)
@@ -96,13 +94,13 @@ def cut_n_predict(data, wcs, scale):
96
 
97
  return y_pred, wcs
98
 
99
- # Define function to decomposed prediction into cavities
100
  @st.cache
101
  def decompose_cavity(pred, th2=0.7, amin=10):
102
  X, Y = pred.nonzero()
103
  data = np.array([X,Y]).reshape(2, -1)
104
 
105
- # DBSCAN CLUSTERING ALGORITHM
106
  try: clusters = DBSCAN(eps=1.0, min_samples=3).fit(data.T).labels_
107
  except: clusters = []
108
 
@@ -115,14 +113,16 @@ def decompose_cavity(pred, th2=0.7, amin=10):
115
  xi, yi = X[b], Y[b]
116
  img[xi, yi] = pred[xi, yi]
117
 
118
- # THRESHOLDING #2
119
  if not (img > th2).any(): continue
120
 
121
- # MINIMAL AREA
122
  if np.sum(img) <= amin: continue
123
 
124
  cavities.append(img)
125
 
 
 
126
  ccd = CCDData(pred, unit="adu", wcs=wcs)
127
  ccd.write(f"predictions/predicted.fits", overwrite=True)
128
  image_decomposed = np.zeros((128,128))
@@ -133,24 +133,14 @@ def decompose_cavity(pred, th2=0.7, amin=10):
133
 
134
  return image_decomposed
135
 
136
- # @st.cache
137
- # def zip_predictions():
138
- # shutil.make_archive("predictions.zip", 'zip', "predictions")
139
- # with open('predictions.zip', 'rb') as f:
140
- # res = f.read()
141
- # return res
142
-
143
  bordersize = 0.6
144
  _, col, _ = st.columns([bordersize, 3, bordersize])
145
 
146
- # if os.path.exists("pred.npy"): os.system("rm pred.npy")
147
- # os.system("rm -r predictions")
148
- # os.system("rm predictions.zip Views")
149
- os.system("mkdir -p predictions")
150
-
151
  with col:
152
- st.markdown("# Cavity Detection Tool")
153
-
154
  st.markdown("Cavity Detection Tool (CADET) is a machine learning pipeline trained to detect X-ray cavities from noisy Chandra images of early-type galaxies.")
155
  st.markdown("To use this tool: upload your image, select the scale of interest, and make a prediction!")
156
  st.markdown("Input images should be centred at the centre of the galaxy and point sources should be filled with surrounding background ([dmfilth](https://cxc.cfa.harvard.edu/ciao/ahelp/dmfilth.html)).")
@@ -170,61 +160,46 @@ if uploaded_file is not None:
170
  col1.subheader("Input image")
171
  col3.subheader("Prediction")
172
  col5.subheader("Decomposed")
173
- col6.subheader("")
174
 
175
  with col1:
176
-
177
  # st.markdown("""<style>[data-baseweb="select"] {margin-top: -26px;}</style>""", unsafe_allow_html=True)
178
  max_scale = int(data.shape[0] // 128)
179
  scale = st.selectbox('Scale:',[f"{(i+1)*128}x{(i+1)*128}" for i in range(max_scale)], label_visibility="hidden")
180
  scale = int(scale.split("x")[0]) // 128
181
-
182
- with col3:
183
- detect = st.button('Detect')
184
 
185
- with col5:
186
- decompose = st.button('Decompose')
 
 
 
 
 
 
 
 
187
 
188
  # Make two columns for plots
189
  _, colA, colB, colC, _ = st.columns([bordersize,1,1,1,bordersize])
190
 
191
  image = np.log10(data+1)
192
  plot_image(image, scale)
193
-
194
-
195
- with col4:
196
- st.markdown("""<style>[data-baseweb="select"] {margin-top: -36px;}</style>""", unsafe_allow_html=True)
197
- threshold = st.slider("", 0.0, 1.0, 0.0, 0.05, label_visibility="hidden")
198
 
199
  if detect or threshold:
200
  y_pred, wcs = cut_n_predict(data, wcs, scale)
201
 
202
- # np.save("pred.npy", y_pred)
203
-
204
- # try: y_pred = np.load("pred.npy")
205
  # except: y_pred = np.zeros((128,128))
206
- try: _ = y_pred
207
- except: y_pred = np.zeros((128,128))
208
  y_pred_th = np.where(y_pred > threshold, y_pred, 0)
209
- # np.save("thresh.npy", y_pred)
210
 
211
  plot_prediction(y_pred_th)
212
 
213
  if decompose:
214
- # y_pred = np.load("thresh.npy")
215
 
216
  image_decomposed = decompose_cavity(y_pred_th)
217
 
218
- # ccd = CCDData(y_pred, unit="adu", wcs=wcs)
219
- # ccd.write(f"predictions/predicted.fits", overwrite=True)
220
- # image_decomposed = np.zeros((128,128))
221
- # for i, cav in enumerate(cavs):
222
- # ccd = CCDData(cav, unit="adu", wcs=wcs)
223
- # ccd.write(f"predictions/predicted_{i+1}.fits", overwrite=True)
224
- # image_decomposed += (i+1) * np.where(cav > 0, 1, 0)
225
-
226
- try: _ = image_decomposed
227
- except: image_decomposed = np.zeros((128,128))
228
  plot_decomposed(image_decomposed)
229
 
230
  with col6:
@@ -235,4 +210,4 @@ if uploaded_file is not None:
235
  st.markdown("")
236
  # st.markdown("""<style>[data-baseweb="select"] {margin-top: 16px;}</style>""", unsafe_allow_html=True)
237
  fname = uploaded_file.name.strip(".fits")
238
- download = st.download_button(label="Download", data=res, file_name=f'pred_{int(scale*128)}.zip', mime="application/octet-stream")
 
24
  # Streamlit
25
  import streamlit as st
26
  st.set_option('deprecation.showPyplotGlobalUse', False)
 
 
27
 
28
  # Define function to plot the uploaded image
29
  def plot_image(image, scale):
 
56
  plt.axis('off')
57
  with colC: st.pyplot()
58
 
59
+ # Define function to cut input image and rebin it to 128x128 pixels
60
  def cut(data0, wcs0, scale=1):
61
  shape = data0.shape[0]
62
  x0 = shape / 2
 
79
 
80
  return data, wcs
81
 
82
+ # Define function to apply cutting and produce a prediction
83
  @st.cache
84
  def cut_n_predict(data, wcs, scale):
85
  data, wcs = cut(data, wcs, scale=scale)
 
94
 
95
  return y_pred, wcs
96
 
97
+ # Define function to decompose prediction into individual cavities
98
  @st.cache
99
  def decompose_cavity(pred, th2=0.7, amin=10):
100
  X, Y = pred.nonzero()
101
  data = np.array([X,Y]).reshape(2, -1)
102
 
103
+ # DBSCAN clustering
104
  try: clusters = DBSCAN(eps=1.0, min_samples=3).fit(data.T).labels_
105
  except: clusters = []
106
 
 
113
  xi, yi = X[b], Y[b]
114
  img[xi, yi] = pred[xi, yi]
115
 
116
+ # Thresholding #2
117
  if not (img > th2).any(): continue
118
 
119
+ # Minimal area
120
  if np.sum(img) <= amin: continue
121
 
122
  cavities.append(img)
123
 
124
+ # Save raw and decomposed predictions to predictions folder
125
+ os.system("mkdir -p predictions")
126
  ccd = CCDData(pred, unit="adu", wcs=wcs)
127
  ccd.write(f"predictions/predicted.fits", overwrite=True)
128
  image_decomposed = np.zeros((128,128))
 
133
 
134
  return image_decomposed
135
 
136
+ # Use wide layout and create columns
137
+ st.set_page_config(page_title="Cavity Detection Tool", layout="wide")
 
 
 
 
 
138
  bordersize = 0.6
139
  _, col, _ = st.columns([bordersize, 3, bordersize])
140
 
 
 
 
 
 
141
  with col:
142
+ # Create heading and description
143
+ st.markdown("# Cavity Detection Tool")
144
  st.markdown("Cavity Detection Tool (CADET) is a machine learning pipeline trained to detect X-ray cavities from noisy Chandra images of early-type galaxies.")
145
  st.markdown("To use this tool: upload your image, select the scale of interest, and make a prediction!")
146
  st.markdown("Input images should be centred at the centre of the galaxy and point sources should be filled with surrounding background ([dmfilth](https://cxc.cfa.harvard.edu/ciao/ahelp/dmfilth.html)).")
 
160
  col1.subheader("Input image")
161
  col3.subheader("Prediction")
162
  col5.subheader("Decomposed")
163
+ # col6.subheader("")
164
 
165
  with col1:
 
166
  # st.markdown("""<style>[data-baseweb="select"] {margin-top: -26px;}</style>""", unsafe_allow_html=True)
167
  max_scale = int(data.shape[0] // 128)
168
  scale = st.selectbox('Scale:',[f"{(i+1)*128}x{(i+1)*128}" for i in range(max_scale)], label_visibility="hidden")
169
  scale = int(scale.split("x")[0]) // 128
 
 
 
170
 
171
+ # Detect button
172
+ with col3: detect = st.button('Detect')
173
+
174
+ # Threshold slider
175
+ with col4:
176
+ # st.markdown("""<style>[data-baseweb="select"] {margin-top: -36px;}</style>""", unsafe_allow_html=True)
177
+ threshold = st.slider("Threshold", 0.0, 1.0, 0.0, 0.05) #, label_visibility="hidden")
178
+
179
+ # Decompose button
180
+ with col5: decompose = st.button('Decompose')
181
 
182
  # Make two columns for plots
183
  _, colA, colB, colC, _ = st.columns([bordersize,1,1,1,bordersize])
184
 
185
  image = np.log10(data+1)
186
  plot_image(image, scale)
 
 
 
 
 
187
 
188
  if detect or threshold:
189
  y_pred, wcs = cut_n_predict(data, wcs, scale)
190
 
191
+ # try: _ = y_pred
 
 
192
  # except: y_pred = np.zeros((128,128))
 
 
193
  y_pred_th = np.where(y_pred > threshold, y_pred, 0)
 
194
 
195
  plot_prediction(y_pred_th)
196
 
197
  if decompose:
 
198
 
199
  image_decomposed = decompose_cavity(y_pred_th)
200
 
201
+ # try: _ = image_decomposed
202
+ # except: image_decomposed = np.zeros((128,128))
 
 
 
 
 
 
 
 
203
  plot_decomposed(image_decomposed)
204
 
205
  with col6:
 
210
  st.markdown("")
211
  # st.markdown("""<style>[data-baseweb="select"] {margin-top: 16px;}</style>""", unsafe_allow_html=True)
212
  fname = uploaded_file.name.strip(".fits")
213
+ download = st.download_button(label="Download", data=res, file_name=f'{fname}_{int(scale*128)}.zip', mime="application/octet-stream")