Plsek commited on
Commit
2af3442
·
1 Parent(s): 58946f4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +55 -57
app.py CHANGED
@@ -184,65 +184,63 @@ with col:
184
  if uploaded_file is not None:
185
  data, wcs = load_file(uploaded_file)
186
 
187
- if "data" not in locals():
188
- data = np.zeros((128,128))
189
-
190
- # Make six columns for buttons
191
- _, col1, col2, col3, col4, col5, col6, _ = st.columns([bordersize,0.5,0.5,0.5,0.5,0.5,0.5,bordersize])
192
- col1.subheader("Input image")
193
- col3.subheader("Prediction")
194
- col5.subheader("Decomposed")
195
- col6.subheader("")
196
-
197
- with col1:
198
- st.markdown("""<style>[data-baseweb="select"] {margin-top: -46px;}</style>""", unsafe_allow_html=True)
199
- max_scale = int(data.shape[0] // 128)
200
- scale = st.selectbox('Scale:',[f"{(i+1)*128}x{(i+1)*128}" for i in range(max_scale)], label_visibility="hidden", on_change=change_scale)
201
- scale = int(scale.split("x")[0]) // 128
202
-
203
- # Detect button
204
- with col3: detect = st.button('Detect', key="detect")
205
-
206
- # Threshold slider
207
- with col4:
208
- st.markdown("")
209
- # st.markdown("""<style>[data-baseweb="select"] {margin-top: -36px;}</style>""", unsafe_allow_html=True)
210
- threshold = st.slider("Threshold", 0.0, 1.0, 0.0, 0.05, key="threshold") #, label_visibility="hidden")
211
 
212
- # Decompose button
213
- with col5: decompose = st.button('Decompose', key="decompose")
 
 
 
214
 
215
- # Make two columns for plots
216
- _, colA, colB, colC, _ = st.columns([bordersize,1,1,1,bordersize])
217
-
218
- image = np.log10(data+1)
219
- plot_image(image, scale)
220
-
221
- if detect or threshold:
222
- # if st.session_state.get("detect", True):
223
- y_pred, wcs = cut_n_predict(data, wcs, scale)
224
 
225
- y_pred_th = np.where(y_pred > threshold, y_pred, 0)
226
-
227
- plot_prediction(y_pred_th)
228
-
229
- if decompose or st.session_state.get("download", False):
230
- image_decomposed = decompose_cavity(y_pred_th)
231
-
232
- plot_decomposed(image_decomposed)
233
-
234
- with col6:
235
- st.markdown("<br style='margin:4px 0'>", unsafe_allow_html=True)
236
- # st.markdown("""<style>[data-baseweb="select"] {margin-top: 16px;}</style>""", unsafe_allow_html=True)
237
- fname = uploaded_file.name.strip(".fits")
238
 
239
- # if st.session_state.get("download", False):
 
 
 
 
 
 
 
 
 
 
 
 
240
 
241
- shutil.make_archive("predictions", 'zip', "predictions")
242
- with open('predictions.zip', 'rb') as f:
243
- res = f.read()
244
-
245
- download = st.download_button(label="Download", data=res, key="download",
246
- file_name=f'{fname}_{int(scale*128)}.zip',
247
- # disabled=st.session_state.get("disabled", True),
248
- mime="application/octet-stream")
 
 
 
 
 
 
 
 
 
 
184
  if uploaded_file is not None:
185
  data, wcs = load_file(uploaded_file)
186
 
187
+ if "data" in locals():
188
+ # Make six columns for buttons
189
+ _, col1, col2, col3, col4, col5, col6, _ = st.columns([bordersize,0.5,0.5,0.5,0.5,0.5,0.5,bordersize])
190
+ col1.subheader("Input image")
191
+ col3.subheader("Prediction")
192
+ col5.subheader("Decomposed")
193
+ col6.subheader("")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
194
 
195
+ with col1:
196
+ st.markdown("""<style>[data-baseweb="select"] {margin-top: -46px;}</style>""", unsafe_allow_html=True)
197
+ max_scale = int(data.shape[0] // 128)
198
+ scale = st.selectbox('Scale:',[f"{(i+1)*128}x{(i+1)*128}" for i in range(max_scale)], label_visibility="hidden", on_change=change_scale)
199
+ scale = int(scale.split("x")[0]) // 128
200
 
201
+ # Detect button
202
+ with col3: detect = st.button('Detect', key="detect")
 
 
 
 
 
 
 
203
 
204
+ # Threshold slider
205
+ with col4:
206
+ st.markdown("")
207
+ # st.markdown("""<style>[data-baseweb="select"] {margin-top: -36px;}</style>""", unsafe_allow_html=True)
208
+ threshold = st.slider("Threshold", 0.0, 1.0, 0.0, 0.05, key="threshold") #, label_visibility="hidden")
209
+
210
+ # Decompose button
211
+ with col5: decompose = st.button('Decompose', key="decompose")
212
+
213
+ # Make two columns for plots
214
+ _, colA, colB, colC, _ = st.columns([bordersize,1,1,1,bordersize])
 
 
215
 
216
+ image = np.log10(data+1)
217
+ plot_image(image, scale)
218
+
219
+ if detect or threshold:
220
+ # if st.session_state.get("detect", True):
221
+ y_pred, wcs = cut_n_predict(data, wcs, scale)
222
+
223
+ y_pred_th = np.where(y_pred > threshold, y_pred, 0)
224
+
225
+ plot_prediction(y_pred_th)
226
+
227
+ if decompose or st.session_state.get("download", False):
228
+ image_decomposed = decompose_cavity(y_pred_th)
229
 
230
+ plot_decomposed(image_decomposed)
231
+
232
+ with col6:
233
+ st.markdown("<br style='margin:4px 0'>", unsafe_allow_html=True)
234
+ # st.markdown("""<style>[data-baseweb="select"] {margin-top: 16px;}</style>""", unsafe_allow_html=True)
235
+ fname = uploaded_file.name.strip(".fits")
236
+
237
+ # if st.session_state.get("download", False):
238
+
239
+ shutil.make_archive("predictions", 'zip', "predictions")
240
+ with open('predictions.zip', 'rb') as f:
241
+ res = f.read()
242
+
243
+ download = st.download_button(label="Download", data=res, key="download",
244
+ file_name=f'{fname}_{int(scale*128)}.zip',
245
+ # disabled=st.session_state.get("disabled", True),
246
+ mime="application/octet-stream")