Plsek commited on
Commit
9bf2a3b
·
1 Parent(s): 3ca195b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +52 -50
app.py CHANGED
@@ -20,15 +20,17 @@ st.set_option('deprecation.showPyplotGlobalUse', False)
20
 
21
  st.set_page_config(page_title="Cavity Detection Tool", layout="wide")
22
 
23
- # st.title("Cavity Detection Tool")
24
 
25
- st.markdown("Cavity Detection Tool (CADET) is a machine learning pipeline trained to detect X-ray cavities from noisy Chandra images of early-type galaxies. To use this tool: upload your image, select the scale of interest, and make a prediction! If you use this tool for your research, please cite [Plšek et al. 2023](https://arxiv.org/abs/2304.05457)")
26
 
27
- st.markdown("Input images should be centered at the centre of the galaxy and point sources should be filled with surrounding background ([dmfilth](https://cxc.cfa.harvard.edu/ciao/ahelp/dmfilth.html)).")
28
-
29
-
30
- # Create file uploader widget
31
- uploaded_file = st.file_uploader("Choose a FITS file", type=['fits'])
 
 
32
 
33
  # Define function to plot the uploaded image
34
  def plot_image(image, scale):
@@ -104,56 +106,56 @@ if uploaded_file is not None:
104
  data = hdul[0].data
105
  wcs = WCS(hdul[0].header)
106
 
107
- # Make four columns for buttons
108
- col1, col2, col3, col4 = st.columns(4)
109
- col1.subheader("Input image")
110
- col3.subheader("Prediction")
111
 
112
- with col1:
113
- st.markdown("""<style>[data-baseweb="select"] {margin-top: 16px;}</style>""", unsafe_allow_html=True)
114
- max_scale = int(data.shape[0] // 128)
115
- # scale = int(st.selectbox('Scale:',[i+1 for i in range(max_scale)], label_visibility="hidden"))
116
- scale = int(st.selectbox('Scale:',[f"{(i+1)*128}x{(i+1)*128}" for i in range(max_scale)], label_visibility="hidden"))
117
- scale = scale.split("x")[0] // 128
118
 
119
- with col2:
120
- detect = st.button('Detect cavities')
121
 
122
- with col3:
123
- decompose = st.button('Docompose cavities')
124
 
125
- # Make two columns for plots
126
- colA, colB = st.columns(2)
127
 
128
- image = np.log10(data+1)
129
- plot_image(image, scale)
130
 
131
- if detect:
132
- data, wcs = cut(data, wcs, scale=scale)
133
- image = np.log10(data+1)
134
 
135
- y_pred = 0
136
- for j in [0,1,2,3]:
137
- rotated = np.rot90(image, j)
138
- pred = model.predict(rotated.reshape(1, 128, 128, 1)).reshape(128 ,128)
139
- pred = np.rot90(pred, -j)
140
- y_pred += pred / 4
141
-
142
- # Thresholding
143
- y_pred = np.where(y_pred > 0.4, y_pred, 0)
144
-
145
- # if decompose:
146
- # cavs = decompose_cavity(y_pred, )
147
 
148
- plot_prediction(y_pred, decompose)
149
 
150
- ccd = CCDData(y_pred, unit="adu", wcs=wcs)
151
- ccd.write("predicted.fits", overwrite=True)
152
- with open('predicted.fits', 'rb') as f:
153
- res = f.read()
154
 
155
- with col4:
156
- pass
157
- st.markdown("""<style>[data-baseweb="select"] {margin-top: 16px;}</style>""", unsafe_allow_html=True)
158
- # # download = st.button('Download')
159
- download = st.download_button(label="Download", data=res, file_name="predicted.fits", mime="application/octet-stream")
 
20
 
21
  st.set_page_config(page_title="Cavity Detection Tool", layout="wide")
22
 
23
+ st.title("Cavity Detection Tool")
24
 
25
+ _, col, _ = st.columns([1, 3, 1])
26
 
27
+ with col:
28
+ st.markdown("Cavity Detection Tool (CADET) is a machine learning pipeline trained to detect X-ray cavities from noisy Chandra images of early-type galaxies. To use this tool: upload your image, select the scale of interest, and make a prediction! If you use this tool for your research, please cite [Plšek et al. 2023](https://arxiv.org/abs/2304.05457)")
29
+
30
+ st.markdown("Input images should be centered at the centre of the galaxy and point sources should be filled with surrounding background ([dmfilth](https://cxc.cfa.harvard.edu/ciao/ahelp/dmfilth.html)).")
31
+
32
+ # Create file uploader widget
33
+ uploaded_file = st.file_uploader("Choose a FITS file", type=['fits'])
34
 
35
  # Define function to plot the uploaded image
36
  def plot_image(image, scale):
 
106
  data = hdul[0].data
107
  wcs = WCS(hdul[0].header)
108
 
109
+ # # Make four columns for buttons
110
+ # col1, col2, col3, col4 = st.columns(4)
111
+ # col1.subheader("Input image")
112
+ # col3.subheader("Prediction")
113
 
114
+ # with col1:
115
+ # st.markdown("""<style>[data-baseweb="select"] {margin-top: 16px;}</style>""", unsafe_allow_html=True)
116
+ # max_scale = int(data.shape[0] // 128)
117
+ # # scale = int(st.selectbox('Scale:',[i+1 for i in range(max_scale)], label_visibility="hidden"))
118
+ # scale = int(st.selectbox('Scale:',[f"{(i+1)*128}x{(i+1)*128}" for i in range(max_scale)], label_visibility="hidden"))
119
+ # scale = scale.split("x")[0] // 128
120
 
121
+ # with col2:
122
+ # detect = st.button('Detect cavities')
123
 
124
+ # with col3:
125
+ # decompose = st.button('Docompose cavities')
126
 
127
+ # # Make two columns for plots
128
+ # colA, colB = st.columns(2)
129
 
130
+ # image = np.log10(data+1)
131
+ # plot_image(image, scale)
132
 
133
+ # if detect:
134
+ # data, wcs = cut(data, wcs, scale=scale)
135
+ # image = np.log10(data+1)
136
 
137
+ # y_pred = 0
138
+ # for j in [0,1,2,3]:
139
+ # rotated = np.rot90(image, j)
140
+ # pred = model.predict(rotated.reshape(1, 128, 128, 1)).reshape(128 ,128)
141
+ # pred = np.rot90(pred, -j)
142
+ # y_pred += pred / 4
143
+
144
+ # # Thresholding
145
+ # y_pred = np.where(y_pred > 0.4, y_pred, 0)
146
+
147
+ # # if decompose:
148
+ # # cavs = decompose_cavity(y_pred, )
149
 
150
+ # plot_prediction(y_pred, decompose)
151
 
152
+ # ccd = CCDData(y_pred, unit="adu", wcs=wcs)
153
+ # ccd.write("predicted.fits", overwrite=True)
154
+ # with open('predicted.fits', 'rb') as f:
155
+ # res = f.read()
156
 
157
+ # with col4:
158
+ # pass
159
+ # st.markdown("""<style>[data-baseweb="select"] {margin-top: 16px;}</style>""", unsafe_allow_html=True)
160
+ # # # download = st.button('Download')
161
+ # download = st.download_button(label="Download", data=res, file_name="predicted.fits", mime="application/octet-stream")