Plsek commited on
Commit
58946f4
·
1 Parent(s): 1b84f6e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +58 -55
app.py CHANGED
@@ -32,6 +32,7 @@ def plot_image(image, scale):
32
  plt.imshow(image, origin="lower")
33
  plt.gca().add_patch(Rectangle((x0-0.5, x0-0.5), scale*128, scale*128, linewidth=1, edgecolor='w', facecolor='none'))
34
  plt.axis('off')
 
35
  with colA: st.pyplot()
36
 
37
  # Define function to plot the prediction
@@ -183,63 +184,65 @@ with col:
183
  if uploaded_file is not None:
184
  data, wcs = load_file(uploaded_file)
185
 
186
- if "data" in locals():
187
- # Make six columns for buttons
188
- _, col1, col2, col3, col4, col5, col6, _ = st.columns([bordersize,0.5,0.5,0.5,0.5,0.5,0.5,bordersize])
189
- col1.subheader("Input image")
190
- col3.subheader("Prediction")
191
- col5.subheader("Decomposed")
192
- col6.subheader("")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
193
 
194
- with col1:
195
- st.markdown("""<style>[data-baseweb="select"] {margin-top: -46px;}</style>""", unsafe_allow_html=True)
196
- max_scale = int(data.shape[0] // 128)
197
- scale = st.selectbox('Scale:',[f"{(i+1)*128}x{(i+1)*128}" for i in range(max_scale)], label_visibility="hidden", on_change=change_scale)
198
- scale = int(scale.split("x")[0]) // 128
199
-
200
- # Detect button
201
- with col3: detect = st.button('Detect', key="detect")
202
-
203
- # Threshold slider
204
- with col4:
205
- st.markdown("")
206
- # st.markdown("""<style>[data-baseweb="select"] {margin-top: -36px;}</style>""", unsafe_allow_html=True)
207
- threshold = st.slider("Threshold", 0.0, 1.0, 0.0, 0.05, key="threshold") #, label_visibility="hidden")
208
-
209
- # Decompose button
210
- with col5: decompose = st.button('Decompose', key="decompose")
211
-
212
- # Make two columns for plots
213
- _, colA, colB, colC, _ = st.columns([bordersize,1,1,1,bordersize])
214
-
215
- image = np.log10(data+1)
216
- plot_image(image, scale)
217
 
218
- if detect or threshold:
219
- # if st.session_state.get("detect", True):
220
- y_pred, wcs = cut_n_predict(data, wcs, scale)
221
-
222
- y_pred_th = np.where(y_pred > threshold, y_pred, 0)
223
-
224
- plot_prediction(y_pred_th)
225
 
226
- if decompose or st.session_state.get("download", False):
227
- image_decomposed = decompose_cavity(y_pred_th)
 
228
 
229
- plot_decomposed(image_decomposed)
 
 
230
 
231
- with col6:
232
- st.markdown("<br style='margin:4px 0'>", unsafe_allow_html=True)
233
- # st.markdown("""<style>[data-baseweb="select"] {margin-top: 16px;}</style>""", unsafe_allow_html=True)
234
- fname = uploaded_file.name.strip(".fits")
235
-
236
- # if st.session_state.get("download", False):
237
-
238
- shutil.make_archive("predictions", 'zip', "predictions")
239
- with open('predictions.zip', 'rb') as f:
240
- res = f.read()
241
-
242
- download = st.download_button(label="Download", data=res, key="download",
243
- file_name=f'{fname}_{int(scale*128)}.zip',
244
- # disabled=st.session_state.get("disabled", True),
245
- mime="application/octet-stream")
 
 
 
 
 
 
32
  plt.imshow(image, origin="lower")
33
  plt.gca().add_patch(Rectangle((x0-0.5, x0-0.5), scale*128, scale*128, linewidth=1, edgecolor='w', facecolor='none'))
34
  plt.axis('off')
35
+ plt.tight_layout()
36
  with colA: st.pyplot()
37
 
38
  # Define function to plot the prediction
 
184
  if uploaded_file is not None:
185
  data, wcs = load_file(uploaded_file)
186
 
187
+ if "data" not in locals():
188
+ data = np.zeros((128,128))
189
+
190
+ # Make six columns for buttons
191
+ _, col1, col2, col3, col4, col5, col6, _ = st.columns([bordersize,0.5,0.5,0.5,0.5,0.5,0.5,bordersize])
192
+ col1.subheader("Input image")
193
+ col3.subheader("Prediction")
194
+ col5.subheader("Decomposed")
195
+ col6.subheader("")
196
+
197
+ with col1:
198
+ st.markdown("""<style>[data-baseweb="select"] {margin-top: -46px;}</style>""", unsafe_allow_html=True)
199
+ max_scale = int(data.shape[0] // 128)
200
+ scale = st.selectbox('Scale:',[f"{(i+1)*128}x{(i+1)*128}" for i in range(max_scale)], label_visibility="hidden", on_change=change_scale)
201
+ scale = int(scale.split("x")[0]) // 128
202
+
203
+ # Detect button
204
+ with col3: detect = st.button('Detect', key="detect")
205
+
206
+ # Threshold slider
207
+ with col4:
208
+ st.markdown("")
209
+ # st.markdown("""<style>[data-baseweb="select"] {margin-top: -36px;}</style>""", unsafe_allow_html=True)
210
+ threshold = st.slider("Threshold", 0.0, 1.0, 0.0, 0.05, key="threshold") #, label_visibility="hidden")
211
 
212
+ # Decompose button
213
+ with col5: decompose = st.button('Decompose', key="decompose")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
214
 
215
+ # Make two columns for plots
216
+ _, colA, colB, colC, _ = st.columns([bordersize,1,1,1,bordersize])
217
+
218
+ image = np.log10(data+1)
219
+ plot_image(image, scale)
 
 
220
 
221
+ if detect or threshold:
222
+ # if st.session_state.get("detect", True):
223
+ y_pred, wcs = cut_n_predict(data, wcs, scale)
224
 
225
+ y_pred_th = np.where(y_pred > threshold, y_pred, 0)
226
+
227
+ plot_prediction(y_pred_th)
228
 
229
+ if decompose or st.session_state.get("download", False):
230
+ image_decomposed = decompose_cavity(y_pred_th)
231
+
232
+ plot_decomposed(image_decomposed)
233
+
234
+ with col6:
235
+ st.markdown("<br style='margin:4px 0'>", unsafe_allow_html=True)
236
+ # st.markdown("""<style>[data-baseweb="select"] {margin-top: 16px;}</style>""", unsafe_allow_html=True)
237
+ fname = uploaded_file.name.strip(".fits")
238
+
239
+ # if st.session_state.get("download", False):
240
+
241
+ shutil.make_archive("predictions", 'zip', "predictions")
242
+ with open('predictions.zip', 'rb') as f:
243
+ res = f.read()
244
+
245
+ download = st.download_button(label="Download", data=res, key="download",
246
+ file_name=f'{fname}_{int(scale*128)}.zip',
247
+ # disabled=st.session_state.get("disabled", True),
248
+ mime="application/octet-stream")