Plsek commited on
Commit
d083cdb
·
1 Parent(s): 7368767

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +71 -75
app.py CHANGED
@@ -25,12 +25,6 @@ from sklearn.cluster import DBSCAN
25
  import streamlit as st
26
  st.set_option('deprecation.showPyplotGlobalUse', False)
27
 
28
- # Use wide layout and create columns
29
- st.set_page_config(page_title="Cavity Detection Tool", layout="wide")
30
- bordersize = 0.45
31
- _, col, _ = st.columns([bordersize, 3, bordersize])
32
-
33
-
34
  # Define function to plot the uploaded image
35
  def plot_image(image, scale):
36
  plt.figure(figsize=(4, 4))
@@ -87,7 +81,7 @@ def cut(data0, wcs0, scale=1):
87
  return data, wcs
88
 
89
  # Define function to apply cutting and produce a prediction
90
- # @st.cache
91
  def cut_n_predict(data, wcs, scale):
92
  data, wcs = cut(data, wcs, scale=scale)
93
  image = np.log10(data+1)
@@ -102,7 +96,7 @@ def cut_n_predict(data, wcs, scale):
102
  return y_pred, wcs
103
 
104
  # Define function to decompose prediction into individual cavities
105
- # @st.cache
106
  def decompose_cavity(pred, th2=0.7, amin=6):
107
  X, Y = pred.nonzero()
108
  data = np.array([X,Y]).reshape(2, -1)
@@ -141,17 +135,24 @@ def decompose_cavity(pred, th2=0.7, amin=6):
141
 
142
  return image_decomposed
143
 
144
- # @st.cache
145
  def load_file(fname):
146
  with fits.open(fname) as hdul:
147
  data = hdul[0].data
148
  wcs = WCS(hdul[0].header)
149
  return data, wcs
150
 
151
- # def change_scale():
152
- # del st.session_state["threshold"]
 
153
 
154
- # os.system("mkdir -p predictions")
 
 
 
 
 
 
155
 
156
  with col:
157
  # Create heading and description
@@ -161,12 +162,10 @@ with col:
161
  st.markdown("Input images should be in units of counts, centred at the galaxy center, and point sources should be filled with surrounding background ([dmfilth](https://cxc.cfa.harvard.edu/ciao/ahelp/dmfilth.html)).")
162
  st.markdown("If you use this tool for your research, please cite [Plšek et al. 2023](https://arxiv.org/abs/2304.05457)")
163
 
164
- uploaded_file = st.file_uploader("Choose a FITS file", type=['fits'])
165
-
166
  # _, col_1, col_2, col_3, _ = st.columns([bordersize, 2.0, 0.5, 0.5, bordersize])
167
 
168
  # with col:
169
- # uploaded_file = st.file_uploader("Choose a FITS file", type=['fits'])
170
 
171
  # with col_2:
172
  # st.markdown("### Examples")
@@ -182,71 +181,68 @@ with col:
182
  # uploaded_file = "NGC5813_example.fits"
183
 
184
  # If file is uploaded, read in the data and plot it
185
- # if uploaded_file is not None:
186
- # data, wcs = load_file(uploaded_file)
187
-
188
- # if "data" not in locals():
189
- # data = np.zeros((128,128))
190
-
191
  if uploaded_file is not None:
192
  data, wcs = load_file(uploaded_file)
 
 
 
193
 
194
- # Make six columns for buttons
195
- _, col1, col2, col3, col4, col5, col6, _ = st.columns([bordersize,0.5,0.5,0.5,0.5,0.5,0.5,bordersize])
196
- col1.subheader("Input image")
197
- col3.subheader("Prediction")
198
- col5.subheader("Decomposed")
199
- col6.subheader("")
200
-
201
- with col1:
202
- st.markdown("""<style>[data-baseweb="select"] {margin-top: -46px;}</style>""", unsafe_allow_html=True)
203
- max_scale = int(data.shape[0] // 128)
204
- scale = st.selectbox('Scale:',[f"{(i+1)*128}x{(i+1)*128}" for i in range(max_scale)], label_visibility="hidden") #, on_change=change_scale)
205
- scale = int(scale.split("x")[0]) // 128
206
-
207
- # Detect button
208
- with col3: detect = st.button('Detect', key="detect")
209
-
210
- # Threshold slider
211
- with col4:
212
- st.markdown("")
213
- # st.markdown("""<style>[data-baseweb="select"] {margin-top: -36px;}</style>""", unsafe_allow_html=True)
214
- threshold = st.slider("Threshold", 0.0, 1.0, 0.0, 0.05, key="threshold") #, label_visibility="hidden")
215
-
216
- # Decompose button
217
- # with col5: decompose = st.button('Decompose', key="decompose")
218
-
219
- # Make two columns for plots
220
- _, colA, colB, colC, _ = st.columns([bordersize,1,1,1,bordersize])
221
 
222
- image = np.log10(data+1)
223
- plot_image(image, scale)
224
 
225
- if detect or threshold:
226
- # if st.session_state.get("detect", True):
227
- y_pred, wcs = cut_n_predict(data, wcs, scale)
228
-
229
- y_pred_th = np.where(y_pred > threshold, y_pred, 0)
230
-
231
- plot_prediction(y_pred_th)
 
 
232
 
233
- # if decompose or st.session_state.get("download", False):
234
- # image_decomposed = decompose_cavity(y_pred_th)
 
 
 
 
 
 
 
 
 
 
 
235
 
236
- # plot_decomposed(image_decomposed)
237
 
238
- # with col6:
239
- # st.markdown("<br style='margin:4px 0'>", unsafe_allow_html=True)
240
- # # st.markdown("""<style>[data-baseweb="select"] {margin-top: 16px;}</style>""", unsafe_allow_html=True)
241
- # fname = uploaded_file.name.strip(".fits")
242
-
243
- # # if st.session_state.get("download", False):
244
-
245
- # shutil.make_archive("predictions", 'zip', "predictions")
246
- # with open('predictions.zip', 'rb') as f:
247
- # res = f.read()
248
-
249
- # download = st.download_button(label="Download", data=res, key="download",
250
- # file_name=f'{fname}_{int(scale*128)}.zip',
251
- # # disabled=st.session_state.get("disabled", True),
252
- # mime="application/octet-stream")
 
25
  import streamlit as st
26
  st.set_option('deprecation.showPyplotGlobalUse', False)
27
 
 
 
 
 
 
 
28
  # Define function to plot the uploaded image
29
  def plot_image(image, scale):
30
  plt.figure(figsize=(4, 4))
 
81
  return data, wcs
82
 
83
  # Define function to apply cutting and produce a prediction
84
+ @st.cache
85
  def cut_n_predict(data, wcs, scale):
86
  data, wcs = cut(data, wcs, scale=scale)
87
  image = np.log10(data+1)
 
96
  return y_pred, wcs
97
 
98
  # Define function to decompose prediction into individual cavities
99
+ @st.cache
100
  def decompose_cavity(pred, th2=0.7, amin=6):
101
  X, Y = pred.nonzero()
102
  data = np.array([X,Y]).reshape(2, -1)
 
135
 
136
  return image_decomposed
137
 
138
+ @st.cache
139
  def load_file(fname):
140
  with fits.open(fname) as hdul:
141
  data = hdul[0].data
142
  wcs = WCS(hdul[0].header)
143
  return data, wcs
144
 
145
+ def change_scale():
146
+ del st.session_state["threshold"]
147
+
148
 
149
+
150
+ # Use wide layout and create columns
151
+ st.set_page_config(page_title="Cavity Detection Tool", layout="wide")
152
+ bordersize = 0.45
153
+ _, col, _ = st.columns([bordersize, 3, bordersize])
154
+
155
+ os.system("mkdir -p predictions")
156
 
157
  with col:
158
  # Create heading and description
 
162
  st.markdown("Input images should be in units of counts, centred at the galaxy center, and point sources should be filled with surrounding background ([dmfilth](https://cxc.cfa.harvard.edu/ciao/ahelp/dmfilth.html)).")
163
  st.markdown("If you use this tool for your research, please cite [Plšek et al. 2023](https://arxiv.org/abs/2304.05457)")
164
 
 
 
165
  # _, col_1, col_2, col_3, _ = st.columns([bordersize, 2.0, 0.5, 0.5, bordersize])
166
 
167
  # with col:
168
+ uploaded_file = st.file_uploader("Choose a FITS file", type=['fits'])
169
 
170
  # with col_2:
171
  # st.markdown("### Examples")
 
181
  # uploaded_file = "NGC5813_example.fits"
182
 
183
  # If file is uploaded, read in the data and plot it
 
 
 
 
 
 
184
  if uploaded_file is not None:
185
  data, wcs = load_file(uploaded_file)
186
+
187
+ if "data" not in locals():
188
+ data = np.zeros((128,128))
189
 
190
+ # Make six columns for buttons
191
+ _, col1, col2, col3, col4, col5, col6, _ = st.columns([bordersize,0.5,0.5,0.5,0.5,0.5,0.5,bordersize])
192
+ col1.subheader("Input image")
193
+ col3.subheader("Prediction")
194
+ col5.subheader("Decomposed")
195
+ col6.subheader("")
196
+
197
+ with col1:
198
+ st.markdown("""<style>[data-baseweb="select"] {margin-top: -46px;}</style>""", unsafe_allow_html=True)
199
+ max_scale = int(data.shape[0] // 128)
200
+ scale = st.selectbox('Scale:',[f"{(i+1)*128}x{(i+1)*128}" for i in range(max_scale)], label_visibility="hidden", on_change=change_scale)
201
+ scale = int(scale.split("x")[0]) // 128
202
+
203
+ # Detect button
204
+ with col3: detect = st.button('Detect', key="detect")
205
+
206
+ # Threshold slider
207
+ with col4:
208
+ st.markdown("")
209
+ # st.markdown("""<style>[data-baseweb="select"] {margin-top: -36px;}</style>""", unsafe_allow_html=True)
210
+ threshold = st.slider("Threshold", 0.0, 1.0, 0.0, 0.05, key="threshold") #, label_visibility="hidden")
 
 
 
 
 
 
211
 
212
+ # Decompose button
213
+ with col5: decompose = st.button('Decompose', key="decompose")
214
 
215
+ # Make two columns for plots
216
+ _, colA, colB, colC, _ = st.columns([bordersize,1,1,1,bordersize])
217
+
218
+ image = np.log10(data+1)
219
+ plot_image(image, scale)
220
+
221
+ if detect or threshold:
222
+ # if st.session_state.get("detect", True):
223
+ y_pred, wcs = cut_n_predict(data, wcs, scale)
224
 
225
+ y_pred_th = np.where(y_pred > threshold, y_pred, 0)
226
+
227
+ plot_prediction(y_pred_th)
228
+
229
+ if decompose or st.session_state.get("download", False):
230
+ image_decomposed = decompose_cavity(y_pred_th)
231
+
232
+ plot_decomposed(image_decomposed)
233
+
234
+ with col6:
235
+ st.markdown("<br style='margin:4px 0'>", unsafe_allow_html=True)
236
+ # st.markdown("""<style>[data-baseweb="select"] {margin-top: 16px;}</style>""", unsafe_allow_html=True)
237
+ fname = uploaded_file.name.strip(".fits")
238
 
239
+ # if st.session_state.get("download", False):
240
 
241
+ shutil.make_archive("predictions", 'zip', "predictions")
242
+ with open('predictions.zip', 'rb') as f:
243
+ res = f.read()
244
+
245
+ download = st.download_button(label="Download", data=res, key="download",
246
+ file_name=f'{fname}_{int(scale*128)}.zip',
247
+ # disabled=st.session_state.get("disabled", True),
248
+ mime="application/octet-stream")