Plsek commited on
Commit
e9ef2e7
·
1 Parent(s): af9ea07

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -31
app.py CHANGED
@@ -27,24 +27,6 @@ st.set_option('deprecation.showPyplotGlobalUse', False)
27
  st.set_page_config(page_title="Cavity Detection Tool", layout="wide")
28
  # st.title("Cavity Detection Tool")
29
 
30
- bordersize = 0.6
31
- _, col, _ = st.columns([bordersize, 3, bordersize])
32
-
33
- os.system("rm -r predictions")
34
- os.system("rm predictions.zip Views")
35
- os.system("mkdir predictions")
36
-
37
- with col:
38
- st.markdown("# Cavity Detection Tool")
39
-
40
- st.markdown("Cavity Detection Tool (CADET) is a machine learning pipeline trained to detect X-ray cavities from noisy Chandra images of early-type galaxies.")
41
- st.markdown("To use this tool: upload your image, select the scale of interest, and make a prediction!")
42
- st.markdown("Input images should be centred at the centre of the galaxy and point sources should be filled with surrounding background ([dmfilth](https://cxc.cfa.harvard.edu/ciao/ahelp/dmfilth.html)).")
43
- st.markdown("If you use this tool for your research, please cite [Plšek et al. 2023](https://arxiv.org/abs/2304.05457)")
44
-
45
- # Create file uploader widget
46
- uploaded_file = st.file_uploader("Choose a FITS file", type=['fits'])
47
-
48
  # Define function to plot the uploaded image
49
  def plot_image(image, scale):
50
  plt.figure(figsize=(4, 4))
@@ -125,6 +107,27 @@ def decompose_cavity(pred, th2=0.7, amin=10):
125
  cavities.append(img)
126
 
127
  return cavities
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
128
 
129
  # If file is uploaded, read in the data and plot it
130
  if uploaded_file is not None:
@@ -171,40 +174,44 @@ if uploaded_file is not None:
171
  pred = np.rot90(pred, -j)
172
  y_pred += pred / 4
173
 
174
- np.save("pred.npy", y_pred)
175
 
176
- try: y_pred = np.load("pred.npy")
 
 
177
  except: y_pred = np.zeros((128,128))
178
  y_pred = np.where(y_pred > threshold, y_pred, 0)
179
- np.save("thresh.npy", y_pred)
180
 
181
  plot_prediction(y_pred)
182
 
183
  if decompose:
184
- y_pred = np.load("thresh.npy")
185
 
186
  cavs = decompose_cavity(y_pred)
187
 
188
  ccd = CCDData(y_pred, unit="adu", wcs=wcs)
189
- ccd.write(f"predictions/predicted.fits", overwrite=True)
190
  image_decomposed = np.zeros((128,128))
191
  for i, cav in enumerate(cavs):
192
- ccd = CCDData(cav, unit="adu", wcs=wcs)
193
- ccd.write(f"predictions/predicted_{i+1}.fits", overwrite=True)
194
  image_decomposed += (i+1) * np.where(cav > 0, 1, 0)
195
 
196
  # shutil.make_archive("predictions.zip", 'zip', "predictions")
197
- np.save("decomposed.npy", image_decomposed)
198
 
199
- try: image_decomposed = np.load("decomposed.npy")
 
 
200
  except: image_decomposed = np.zeros((128,128))
201
  plot_decomposed(image_decomposed)
202
 
203
- shutil.make_archive("predictions", 'zip', "predictions")
204
 
205
  with col6:
206
- with open('predictions.zip', 'rb') as f:
207
- res = f.read()
208
  # # st.markdown("""<style>[data-baseweb="select"] {margin-top: 16px;}</style>""", unsafe_allow_html=True)
209
  # # # download = st.button('Download')
210
- download = st.download_button(label="Download", data=res, file_name='predictions.zip', mime="application/octet-stream")
 
27
  st.set_page_config(page_title="Cavity Detection Tool", layout="wide")
28
  # st.title("Cavity Detection Tool")
29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  # Define function to plot the uploaded image
31
  def plot_image(image, scale):
32
  plt.figure(figsize=(4, 4))
 
107
  cavities.append(img)
108
 
109
  return cavities
110
+
111
+
112
+
113
+
114
+ bordersize = 0.6
115
+ _, col, _ = st.columns([bordersize, 3, bordersize])
116
+
117
+ # os.system("rm -r predictions")
118
+ # os.system("rm predictions.zip Views")
119
+ # os.system("mkdir -p predictions")
120
+
121
+ with col:
122
+ st.markdown("# Cavity Detection Tool")
123
+
124
+ st.markdown("Cavity Detection Tool (CADET) is a machine learning pipeline trained to detect X-ray cavities from noisy Chandra images of early-type galaxies.")
125
+ st.markdown("To use this tool: upload your image, select the scale of interest, and make a prediction!")
126
+ st.markdown("Input images should be centred at the centre of the galaxy and point sources should be filled with surrounding background ([dmfilth](https://cxc.cfa.harvard.edu/ciao/ahelp/dmfilth.html)).")
127
+ st.markdown("If you use this tool for your research, please cite [Plšek et al. 2023](https://arxiv.org/abs/2304.05457)")
128
+
129
+ # Create file uploader widget
130
+ uploaded_file = st.file_uploader("Choose a FITS file", type=['fits'])
131
 
132
  # If file is uploaded, read in the data and plot it
133
  if uploaded_file is not None:
 
174
  pred = np.rot90(pred, -j)
175
  y_pred += pred / 4
176
 
177
+ # np.save("pred.npy", y_pred)
178
 
179
+ # try: y_pred = np.load("pred.npy")
180
+ # except: y_pred = np.zeros((128,128))
181
+ try: y_pred
182
  except: y_pred = np.zeros((128,128))
183
  y_pred = np.where(y_pred > threshold, y_pred, 0)
184
+ # np.save("thresh.npy", y_pred)
185
 
186
  plot_prediction(y_pred)
187
 
188
  if decompose:
189
+ # y_pred = np.load("thresh.npy")
190
 
191
  cavs = decompose_cavity(y_pred)
192
 
193
  ccd = CCDData(y_pred, unit="adu", wcs=wcs)
194
+ # ccd.write(f"predictions/predicted.fits", overwrite=True)
195
  image_decomposed = np.zeros((128,128))
196
  for i, cav in enumerate(cavs):
197
+ # ccd = CCDData(cav, unit="adu", wcs=wcs)
198
+ # ccd.write(f"predictions/predicted_{i+1}.fits", overwrite=True)
199
  image_decomposed += (i+1) * np.where(cav > 0, 1, 0)
200
 
201
  # shutil.make_archive("predictions.zip", 'zip', "predictions")
202
+ # np.save("decomposed.npy", image_decomposed)
203
 
204
+ # try: image_decomposed = np.load("decomposed.npy")
205
+ # except: image_decomposed = np.zeros((128,128))
206
+ try: image_decomposed
207
  except: image_decomposed = np.zeros((128,128))
208
  plot_decomposed(image_decomposed)
209
 
210
+ # shutil.make_archive("predictions", 'zip', "predictions")
211
 
212
  with col6:
213
+ # with open('predictions.zip', 'rb') as f:
214
+ # res = f.read()
215
  # # st.markdown("""<style>[data-baseweb="select"] {margin-top: 16px;}</style>""", unsafe_allow_html=True)
216
  # # # download = st.button('Download')
217
+ download = st.download_button(label="Download", data=ccd, file_name='prediction.fits', mime="application/octet-stream")