Plsek commited on
Commit
08bba4c
·
1 Parent(s): 7a50c74

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +186 -186
app.py CHANGED
@@ -25,224 +25,224 @@ from sklearn.cluster import DBSCAN
25
  import streamlit as st
26
  st.set_option('deprecation.showPyplotGlobalUse', False)
27
 
28
- # Define function to plot the uploaded image
29
- def plot_image(image, scale):
30
- plt.figure(figsize=(4, 4))
31
- x0 = image.shape[0] // 2 - scale * 128 / 2
32
- plt.imshow(image, origin="lower")
33
- plt.gca().add_patch(Rectangle((x0-0.5, x0-0.5), scale*128, scale*128, linewidth=1, edgecolor='w', facecolor='none'))
34
- plt.axis('off')
35
- plt.tight_layout()
36
- with colA: st.pyplot()
37
-
38
- # Define function to plot the prediction
39
- def plot_prediction(pred):
40
- plt.figure(figsize=(4, 4))
41
- plt.imshow(pred, origin="lower")
42
- plt.axis('off')
43
- with colB: st.pyplot()
44
-
45
- # Define function to plot the decomposed prediction
46
- def plot_decomposed(decomposed):
47
- plt.figure(figsize=(4, 4))
48
- plt.imshow(decomposed, origin="lower") #, norm=LogNorm())
49
-
50
- N = int(np.max(decomposed))
51
- for i in range(N):
52
- new = np.where(decomposed == i+1, 1, 0)
53
- x0, y0 = center_of_mass(new)
54
- color = "white" if i < N//2 else "black"
55
- plt.text(y0, x0, f"{i+1}", ha="center", va="center", fontsize=15, color=color)
56
 
57
- plt.axis('off')
58
- with colC: st.pyplot()
59
 
60
- # Define function to cut input image and rebin it to 128x128 pixels
61
- def cut(data0, wcs0, scale=1):
62
- shape = data0.shape[0]
63
- x0 = shape / 2
64
- size = 128 * scale
65
- cutout = Cutout2D(data0, (x0, x0), (size, size), wcs=wcs0)
66
- data, wcs = cutout.data, cutout.wcs
67
-
68
- # Regrid data
69
- factor = size // 128
70
- data = data.reshape(128, factor, 128, factor).mean(-1).mean(1)
71
 
72
- # Regrid wcs
73
- ra, dec = wcs.wcs_pix2world(np.array([[63, 63]]),0)[0]
74
- wcs.wcs.cdelt[0] = wcs.wcs.cdelt[0] * factor
75
- wcs.wcs.cdelt[1] = wcs.wcs.cdelt[1] * factor
76
- wcs.wcs.crval[0] = ra
77
- wcs.wcs.crval[1] = dec
78
- wcs.wcs.crpix[0] = 64 / factor
79
- wcs.wcs.crpix[1] = 64 / factor
80
-
81
- return data, wcs
82
-
83
- # Define function to apply cutting and produce a prediction
84
- # @st.cache
85
- def cut_n_predict(data, wcs, scale):
86
- data, wcs = cut(data, wcs, scale=scale)
87
- image = np.log10(data+1)
88
 
89
- y_pred = 0
90
- for j in [0,1,2,3]:
91
- rotated = np.rot90(image, j)
92
- pred = model.predict(rotated.reshape(1, 128, 128, 1)).reshape(128 ,128)
93
- pred = np.rot90(pred, -j)
94
- y_pred += pred / 4
95
-
96
- return y_pred, wcs
97
-
98
- # Define function to decompose prediction into individual cavities
99
- # @st.cache
100
- def decompose_cavity(pred, th2=0.7, amin=6):
101
- X, Y = pred.nonzero()
102
- data = np.array([X,Y]).reshape(2, -1)
103
-
104
- # DBSCAN clustering
105
- try: clusters = DBSCAN(eps=1.0, min_samples=3).fit(data.T).labels_
106
- except: clusters = []
107
-
108
- N = len(set(clusters))
109
- cavities = []
110
-
111
- for i in range(N):
112
- img = np.zeros((128,128))
113
- b = clusters == i
114
- xi, yi = X[b], Y[b]
115
- img[xi, yi] = pred[xi, yi]
116
-
117
- # # Thresholding #2
118
- # if not (img > th2).any(): continue
119
-
120
- # Minimal area
121
- if np.sum(img) <= amin: continue
122
-
123
- cavities.append(img)
124
-
125
- # Save raw and decomposed predictions to predictions folder
126
- ccd = CCDData(pred, unit="adu", wcs=wcs)
127
- ccd.write(f"predictions/predicted.fits", overwrite=True)
128
- image_decomposed = np.zeros((128,128))
129
- for i, cav in enumerate(cavities):
130
- ccd = CCDData(cav, unit="adu", wcs=wcs)
131
- ccd.write(f"predictions/predicted_{i+1}.fits", overwrite=True)
132
- image_decomposed += (i+1) * np.where(cav > 0, 1, 0)
133
-
134
- # shutil.make_archive("predictions", 'zip', "predictions")
135
 
136
- return image_decomposed
137
 
138
- # @st.cache
139
- def load_file(fname):
140
- with fits.open(fname) as hdul:
141
- data = hdul[0].data
142
- wcs = WCS(hdul[0].header)
143
- return data, wcs
144
 
145
- def change_scale():
146
- del st.session_state["threshold"]
147
 
148
 
149
 
150
- # Use wide layout and create columns
151
- st.set_page_config(page_title="Cavity Detection Tool", layout="wide")
152
- bordersize = 0.45
153
- _, col, _ = st.columns([bordersize, 3, bordersize])
154
 
155
- os.system("mkdir -p predictions")
156
 
157
- with col:
158
- # Create heading and description
159
- st.markdown("<h1 align='center'>Cavity Detection Tool</h1>", unsafe_allow_html=True)
160
- st.markdown("Cavity Detection Tool (CADET) is a machine learning pipeline trained to detect X-ray cavities from noisy Chandra images of early-type galaxies.")
161
- st.markdown("To use this tool: upload your image, select the scale of interest, make a prediction, and decompose it into individual cavities!")
162
- st.markdown("Input images should be in units of counts, centred at the galaxy center, and point sources should be filled with surrounding background ([dmfilth](https://cxc.cfa.harvard.edu/ciao/ahelp/dmfilth.html)).")
163
- st.markdown("If you use this tool for your research, please cite [Plšek et al. 2023](https://arxiv.org/abs/2304.05457)")
164
 
165
- # _, col_1, col_2, col_3, _ = st.columns([bordersize, 2.0, 0.5, 0.5, bordersize])
166
 
167
- # with col:
168
- uploaded_file = st.file_uploader("Choose a FITS file", type=['fits'])
169
 
170
- # with col_2:
171
- # st.markdown("### Examples")
172
- # NGC4649 = st.button("NGC4649")
173
 
174
- # with col_3:
175
- # st.markdown("""<style>[data-baseweb="select"] {margin-top: 26px;}</style>""", unsafe_allow_html=True)
176
- # NGC5813 = st.button("NGC5813")
177
 
178
- # if NGC4649:
179
- # uploaded_file = "NGC4649_example.fits"
180
- # elif NGC5813:
181
- # uploaded_file = "NGC5813_example.fits"
182
 
183
- # # If file is uploaded, read in the data and plot it
184
- # if uploaded_file is not None:
185
- # data, wcs = load_file(uploaded_file)
186
 
187
- # # if "data" not in locals():
188
- # # data = np.zeros((128,128))
189
 
190
- # # Make six columns for buttons
191
- # _, col1, col2, col3, col4, col5, col6, _ = st.columns([bordersize,0.5,0.5,0.5,0.5,0.5,0.5,bordersize])
192
- # col1.subheader("Input image")
193
- # col3.subheader("Prediction")
194
- # col5.subheader("Decomposed")
195
- # col6.subheader("")
196
 
197
- # with col1:
198
- # st.markdown("""<style>[data-baseweb="select"] {margin-top: -46px;}</style>""", unsafe_allow_html=True)
199
- # max_scale = int(data.shape[0] // 128)
200
- # scale = st.selectbox('Scale:',[f"{(i+1)*128}x{(i+1)*128}" for i in range(max_scale)], label_visibility="hidden", on_change=change_scale)
201
- # scale = int(scale.split("x")[0]) // 128
202
 
203
- # # Detect button
204
- # with col3: detect = st.button('Detect', key="detect")
205
 
206
- # # Threshold slider
207
- # with col4:
208
- # st.markdown("")
209
- # # st.markdown("""<style>[data-baseweb="select"] {margin-top: -36px;}</style>""", unsafe_allow_html=True)
210
- # threshold = st.slider("Threshold", 0.0, 1.0, 0.0, 0.05, key="threshold") #, label_visibility="hidden")
211
 
212
- # # Decompose button
213
- # with col5: decompose = st.button('Decompose', key="decompose")
214
 
215
- # # Make two columns for plots
216
- # _, colA, colB, colC, _ = st.columns([bordersize,1,1,1,bordersize])
217
 
218
- # image = np.log10(data+1)
219
- # plot_image(image, scale)
220
 
221
- # if detect or threshold:
222
- # # if st.session_state.get("detect", True):
223
- # y_pred, wcs = cut_n_predict(data, wcs, scale)
224
 
225
- # y_pred_th = np.where(y_pred > threshold, y_pred, 0)
226
 
227
- # plot_prediction(y_pred_th)
228
 
229
- # if decompose or st.session_state.get("download", False):
230
- # image_decomposed = decompose_cavity(y_pred_th)
231
 
232
- # plot_decomposed(image_decomposed)
233
 
234
- # with col6:
235
- # st.markdown("<br style='margin:4px 0'>", unsafe_allow_html=True)
236
- # # st.markdown("""<style>[data-baseweb="select"] {margin-top: 16px;}</style>""", unsafe_allow_html=True)
237
- # fname = uploaded_file.name.strip(".fits")
238
 
239
- # # if st.session_state.get("download", False):
240
 
241
- # shutil.make_archive("predictions", 'zip', "predictions")
242
- # with open('predictions.zip', 'rb') as f:
243
- # res = f.read()
244
 
245
- # download = st.download_button(label="Download", data=res, key="download",
246
- # file_name=f'{fname}_{int(scale*128)}.zip',
247
- # # disabled=st.session_state.get("disabled", True),
248
- # mime="application/octet-stream")
 
25
  import streamlit as st
26
  st.set_option('deprecation.showPyplotGlobalUse', False)
27
 
28
+ # # Define function to plot the uploaded image
29
+ # def plot_image(image, scale):
30
+ # plt.figure(figsize=(4, 4))
31
+ # x0 = image.shape[0] // 2 - scale * 128 / 2
32
+ # plt.imshow(image, origin="lower")
33
+ # plt.gca().add_patch(Rectangle((x0-0.5, x0-0.5), scale*128, scale*128, linewidth=1, edgecolor='w', facecolor='none'))
34
+ # plt.axis('off')
35
+ # plt.tight_layout()
36
+ # with colA: st.pyplot()
37
+
38
+ # # Define function to plot the prediction
39
+ # def plot_prediction(pred):
40
+ # plt.figure(figsize=(4, 4))
41
+ # plt.imshow(pred, origin="lower")
42
+ # plt.axis('off')
43
+ # with colB: st.pyplot()
44
+
45
+ # # Define function to plot the decomposed prediction
46
+ # def plot_decomposed(decomposed):
47
+ # plt.figure(figsize=(4, 4))
48
+ # plt.imshow(decomposed, origin="lower") #, norm=LogNorm())
49
+
50
+ # N = int(np.max(decomposed))
51
+ # for i in range(N):
52
+ # new = np.where(decomposed == i+1, 1, 0)
53
+ # x0, y0 = center_of_mass(new)
54
+ # color = "white" if i < N//2 else "black"
55
+ # plt.text(y0, x0, f"{i+1}", ha="center", va="center", fontsize=15, color=color)
56
 
57
+ # plt.axis('off')
58
+ # with colC: st.pyplot()
59
 
60
+ # # Define function to cut input image and rebin it to 128x128 pixels
61
+ # def cut(data0, wcs0, scale=1):
62
+ # shape = data0.shape[0]
63
+ # x0 = shape / 2
64
+ # size = 128 * scale
65
+ # cutout = Cutout2D(data0, (x0, x0), (size, size), wcs=wcs0)
66
+ # data, wcs = cutout.data, cutout.wcs
67
+
68
+ # # Regrid data
69
+ # factor = size // 128
70
+ # data = data.reshape(128, factor, 128, factor).mean(-1).mean(1)
71
 
72
+ # # Regrid wcs
73
+ # ra, dec = wcs.wcs_pix2world(np.array([[63, 63]]),0)[0]
74
+ # wcs.wcs.cdelt[0] = wcs.wcs.cdelt[0] * factor
75
+ # wcs.wcs.cdelt[1] = wcs.wcs.cdelt[1] * factor
76
+ # wcs.wcs.crval[0] = ra
77
+ # wcs.wcs.crval[1] = dec
78
+ # wcs.wcs.crpix[0] = 64 / factor
79
+ # wcs.wcs.crpix[1] = 64 / factor
80
+
81
+ # return data, wcs
82
+
83
+ # # Define function to apply cutting and produce a prediction
84
+ # # @st.cache
85
+ # def cut_n_predict(data, wcs, scale):
86
+ # data, wcs = cut(data, wcs, scale=scale)
87
+ # image = np.log10(data+1)
88
 
89
+ # y_pred = 0
90
+ # for j in [0,1,2,3]:
91
+ # rotated = np.rot90(image, j)
92
+ # pred = model.predict(rotated.reshape(1, 128, 128, 1)).reshape(128 ,128)
93
+ # pred = np.rot90(pred, -j)
94
+ # y_pred += pred / 4
95
+
96
+ # return y_pred, wcs
97
+
98
+ # # Define function to decompose prediction into individual cavities
99
+ # # @st.cache
100
+ # def decompose_cavity(pred, th2=0.7, amin=6):
101
+ # X, Y = pred.nonzero()
102
+ # data = np.array([X,Y]).reshape(2, -1)
103
+
104
+ # # DBSCAN clustering
105
+ # try: clusters = DBSCAN(eps=1.0, min_samples=3).fit(data.T).labels_
106
+ # except: clusters = []
107
+
108
+ # N = len(set(clusters))
109
+ # cavities = []
110
+
111
+ # for i in range(N):
112
+ # img = np.zeros((128,128))
113
+ # b = clusters == i
114
+ # xi, yi = X[b], Y[b]
115
+ # img[xi, yi] = pred[xi, yi]
116
+
117
+ # # # Thresholding #2
118
+ # # if not (img > th2).any(): continue
119
+
120
+ # # Minimal area
121
+ # if np.sum(img) <= amin: continue
122
+
123
+ # cavities.append(img)
124
+
125
+ # # Save raw and decomposed predictions to predictions folder
126
+ # ccd = CCDData(pred, unit="adu", wcs=wcs)
127
+ # ccd.write(f"predictions/predicted.fits", overwrite=True)
128
+ # image_decomposed = np.zeros((128,128))
129
+ # for i, cav in enumerate(cavities):
130
+ # ccd = CCDData(cav, unit="adu", wcs=wcs)
131
+ # ccd.write(f"predictions/predicted_{i+1}.fits", overwrite=True)
132
+ # image_decomposed += (i+1) * np.where(cav > 0, 1, 0)
133
+
134
+ # # shutil.make_archive("predictions", 'zip', "predictions")
135
 
136
+ # return image_decomposed
137
 
138
+ # # @st.cache
139
+ # def load_file(fname):
140
+ # with fits.open(fname) as hdul:
141
+ # data = hdul[0].data
142
+ # wcs = WCS(hdul[0].header)
143
+ # return data, wcs
144
 
145
+ # def change_scale():
146
+ # del st.session_state["threshold"]
147
 
148
 
149
 
150
+ # # Use wide layout and create columns
151
+ # st.set_page_config(page_title="Cavity Detection Tool", layout="wide")
152
+ # bordersize = 0.45
153
+ # _, col, _ = st.columns([bordersize, 3, bordersize])
154
 
155
+ # os.system("mkdir -p predictions")
156
 
157
+ # with col:
158
+ # # Create heading and description
159
+ # st.markdown("<h1 align='center'>Cavity Detection Tool</h1>", unsafe_allow_html=True)
160
+ # st.markdown("Cavity Detection Tool (CADET) is a machine learning pipeline trained to detect X-ray cavities from noisy Chandra images of early-type galaxies.")
161
+ # st.markdown("To use this tool: upload your image, select the scale of interest, make a prediction, and decompose it into individual cavities!")
162
+ # st.markdown("Input images should be in units of counts, centred at the galaxy center, and point sources should be filled with surrounding background ([dmfilth](https://cxc.cfa.harvard.edu/ciao/ahelp/dmfilth.html)).")
163
+ # st.markdown("If you use this tool for your research, please cite [Plšek et al. 2023](https://arxiv.org/abs/2304.05457)")
164
 
165
+ # # _, col_1, col_2, col_3, _ = st.columns([bordersize, 2.0, 0.5, 0.5, bordersize])
166
 
167
+ # # with col:
168
+ # uploaded_file = st.file_uploader("Choose a FITS file", type=['fits'])
169
 
170
+ # # with col_2:
171
+ # # st.markdown("### Examples")
172
+ # # NGC4649 = st.button("NGC4649")
173
 
174
+ # # with col_3:
175
+ # # st.markdown("""<style>[data-baseweb="select"] {margin-top: 26px;}</style>""", unsafe_allow_html=True)
176
+ # # NGC5813 = st.button("NGC5813")
177
 
178
+ # # if NGC4649:
179
+ # # uploaded_file = "NGC4649_example.fits"
180
+ # # elif NGC5813:
181
+ # # uploaded_file = "NGC5813_example.fits"
182
 
183
+ # # # If file is uploaded, read in the data and plot it
184
+ # # if uploaded_file is not None:
185
+ # # data, wcs = load_file(uploaded_file)
186
 
187
+ # # # if "data" not in locals():
188
+ # # # data = np.zeros((128,128))
189
 
190
+ # # # Make six columns for buttons
191
+ # # _, col1, col2, col3, col4, col5, col6, _ = st.columns([bordersize,0.5,0.5,0.5,0.5,0.5,0.5,bordersize])
192
+ # # col1.subheader("Input image")
193
+ # # col3.subheader("Prediction")
194
+ # # col5.subheader("Decomposed")
195
+ # # col6.subheader("")
196
 
197
+ # # with col1:
198
+ # # st.markdown("""<style>[data-baseweb="select"] {margin-top: -46px;}</style>""", unsafe_allow_html=True)
199
+ # # max_scale = int(data.shape[0] // 128)
200
+ # # scale = st.selectbox('Scale:',[f"{(i+1)*128}x{(i+1)*128}" for i in range(max_scale)], label_visibility="hidden", on_change=change_scale)
201
+ # # scale = int(scale.split("x")[0]) // 128
202
 
203
+ # # # Detect button
204
+ # # with col3: detect = st.button('Detect', key="detect")
205
 
206
+ # # # Threshold slider
207
+ # # with col4:
208
+ # # st.markdown("")
209
+ # # # st.markdown("""<style>[data-baseweb="select"] {margin-top: -36px;}</style>""", unsafe_allow_html=True)
210
+ # # threshold = st.slider("Threshold", 0.0, 1.0, 0.0, 0.05, key="threshold") #, label_visibility="hidden")
211
 
212
+ # # # Decompose button
213
+ # # with col5: decompose = st.button('Decompose', key="decompose")
214
 
215
+ # # # Make two columns for plots
216
+ # # _, colA, colB, colC, _ = st.columns([bordersize,1,1,1,bordersize])
217
 
218
+ # # image = np.log10(data+1)
219
+ # # plot_image(image, scale)
220
 
221
+ # # if detect or threshold:
222
+ # # # if st.session_state.get("detect", True):
223
+ # # y_pred, wcs = cut_n_predict(data, wcs, scale)
224
 
225
+ # # y_pred_th = np.where(y_pred > threshold, y_pred, 0)
226
 
227
+ # # plot_prediction(y_pred_th)
228
 
229
+ # # if decompose or st.session_state.get("download", False):
230
+ # # image_decomposed = decompose_cavity(y_pred_th)
231
 
232
+ # # plot_decomposed(image_decomposed)
233
 
234
+ # # with col6:
235
+ # # st.markdown("<br style='margin:4px 0'>", unsafe_allow_html=True)
236
+ # # # st.markdown("""<style>[data-baseweb="select"] {margin-top: 16px;}</style>""", unsafe_allow_html=True)
237
+ # # fname = uploaded_file.name.strip(".fits")
238
 
239
+ # # # if st.session_state.get("download", False):
240
 
241
+ # # shutil.make_archive("predictions", 'zip', "predictions")
242
+ # # with open('predictions.zip', 'rb') as f:
243
+ # # res = f.read()
244
 
245
+ # # download = st.download_button(label="Download", data=res, key="download",
246
+ # # file_name=f'{fname}_{int(scale*128)}.zip',
247
+ # # # disabled=st.session_state.get("disabled", True),
248
+ # # mime="application/octet-stream")