Plsek commited on
Commit
e52c641
·
1 Parent(s): 2331f7f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -11
app.py CHANGED
@@ -3,6 +3,8 @@ from huggingface_hub import from_pretrained_keras
3
  model = from_pretrained_keras("Plsek/CADET-v1")
4
 
5
  # Basic libraries
 
 
6
  import numpy as np
7
  import matplotlib.pyplot as plt
8
  from matplotlib.patches import Rectangle
@@ -20,19 +22,20 @@ from sklearn.cluster import DBSCAN
20
  # Streamlit
21
  import streamlit as st
22
  st.set_option('deprecation.showPyplotGlobalUse', False)
23
-
24
  st.set_page_config(page_title="Cavity Detection Tool", layout="wide")
25
-
26
  # st.title("Cavity Detection Tool")
27
 
28
  bordersize = 0.6
29
  _, col, _ = st.columns([bordersize, 3, bordersize])
30
 
 
 
31
  with col:
32
  st.markdown("# Cavity Detection Tool")
33
 
34
- st.markdown("Cavity Detection Tool (CADET) is a machine learning pipeline trained to detect X-ray cavities from noisy Chandra images of early-type galaxies.<br> To use this tool: upload your image, select the scale of interest, and make a prediction!<br> If you use this tool for your research, please cite [Plšek et al. 2023](https://arxiv.org/abs/2304.05457)")
35
-
 
36
  st.markdown("Input images should be centered at the centre of the galaxy and point sources should be filled with surrounding background ([dmfilth](https://cxc.cfa.harvard.edu/ciao/ahelp/dmfilth.html)).")
37
 
38
  # Create file uploader widget
@@ -183,17 +186,17 @@ if uploaded_file is not None:
183
  for i, cav in enumerate(cavs):
184
  ccd = CCDData(cav, unit="adu", wcs=wcs)
185
  ccd.write(f"predicted_{i+1}.fits", overwrite=True)
 
186
 
187
- image_decomposed += (i+1) * cav
188
-
189
  np.save("decomposed.npy", image_decomposed)
190
 
191
  try: image_decomposed = np.load("decomposed.npy")
192
  except: image_decomposed = np.zeros((128,128))
193
  plot_decomposed(image_decomposed)
194
 
195
- # with col4:
196
- # pass
197
- # st.markdown("""<style>[data-baseweb="select"] {margin-top: 16px;}</style>""", unsafe_allow_html=True)
198
- # # # download = st.button('Download')
199
- # download = st.download_button(label="Download", data=res, file_name="predicted.fits", mime="application/octet-stream")
 
3
  model = from_pretrained_keras("Plsek/CADET-v1")
4
 
5
  # Basic libraries
6
+ import os
7
+ import shutil
8
  import numpy as np
9
  import matplotlib.pyplot as plt
10
  from matplotlib.patches import Rectangle
 
22
  # Streamlit
23
  import streamlit as st
24
  st.set_option('deprecation.showPyplotGlobalUse', False)
 
25
  st.set_page_config(page_title="Cavity Detection Tool", layout="wide")
 
26
  # st.title("Cavity Detection Tool")
27
 
28
  bordersize = 0.6
29
  _, col, _ = st.columns([bordersize, 3, bordersize])
30
 
31
+ os.system("mkdir predictions")
32
+
33
  with col:
34
  st.markdown("# Cavity Detection Tool")
35
 
36
+ st.markdown("Cavity Detection Tool (CADET) is a machine learning pipeline trained to detect X-ray cavities from noisy Chandra images of early-type galaxies.")
37
+ st.markdown("To use this tool: upload your image, select the scale of interest, and make a prediction!")
38
+ st.markdown("If you use this tool for your research, please cite [Plšek et al. 2023](https://arxiv.org/abs/2304.05457)")
39
  st.markdown("Input images should be centered at the centre of the galaxy and point sources should be filled with surrounding background ([dmfilth](https://cxc.cfa.harvard.edu/ciao/ahelp/dmfilth.html)).")
40
 
41
  # Create file uploader widget
 
186
  for i, cav in enumerate(cavs):
187
  ccd = CCDData(cav, unit="adu", wcs=wcs)
188
  ccd.write(f"predicted_{i+1}.fits", overwrite=True)
189
+ image_decomposed += (i+1) * np.where(cav > 0, 1, 0)
190
 
191
+ shutil.make_archive("predictions", 'zip', "predictions")
 
192
  np.save("decomposed.npy", image_decomposed)
193
 
194
  try: image_decomposed = np.load("decomposed.npy")
195
  except: image_decomposed = np.zeros((128,128))
196
  plot_decomposed(image_decomposed)
197
 
198
+ with col6:
199
+ pass
200
+ # st.markdown("""<style>[data-baseweb="select"] {margin-top: 16px;}</style>""", unsafe_allow_html=True)
201
+ # # download = st.button('Download')
202
+ download = st.download_button(label="Download", data=res, file_name="predicted.zip", mime="application/octet-stream")