Plsek committed on
Commit 4fed3df · 1 Parent(s): 08bba4c

Update app.py

Files changed (1)
  1. app.py +214 -199
app.py CHANGED
@@ -1,14 +1,10 @@
- # HuggingFace Hub
- from huggingface_hub import from_pretrained_keras
- model = from_pretrained_keras("Plsek/CADET-v1")
-
  # Basic libraries
  import os
  import shutil
  import numpy as np
  from scipy.ndimage import center_of_mass
  import matplotlib.pyplot as plt
- from matplotlib.colors import LogNorm
  from matplotlib.patches import Rectangle
 
  # Astropy
@@ -24,225 +20,244 @@ from sklearn.cluster import DBSCAN
  # Streamlit
  import streamlit as st
  st.set_option('deprecation.showPyplotGlobalUse', False)
 
- # # Define function to plot the uploaded image
- # def plot_image(image, scale):
- # plt.figure(figsize=(4, 4))
- # x0 = image.shape[0] // 2 - scale * 128 / 2
- # plt.imshow(image, origin="lower")
- # plt.gca().add_patch(Rectangle((x0-0.5, x0-0.5), scale*128, scale*128, linewidth=1, edgecolor='w', facecolor='none'))
- # plt.axis('off')
- # plt.tight_layout()
- # with colA: st.pyplot()
-
- # # Define function to plot the prediction
- # def plot_prediction(pred):
- # plt.figure(figsize=(4, 4))
- # plt.imshow(pred, origin="lower")
- # plt.axis('off')
- # with colB: st.pyplot()
-
- # # Define function to plot the decomposed prediction
- # def plot_decomposed(decomposed):
- # plt.figure(figsize=(4, 4))
- # plt.imshow(decomposed, origin="lower") #, norm=LogNorm())
-
- # N = int(np.max(decomposed))
- # for i in range(N):
- # new = np.where(decomposed == i+1, 1, 0)
- # x0, y0 = center_of_mass(new)
- # color = "white" if i < N//2 else "black"
- # plt.text(y0, x0, f"{i+1}", ha="center", va="center", fontsize=15, color=color)
-
- # plt.axis('off')
- # with colC: st.pyplot()
-
- # # Define function to cut input image and rebin it to 128x128 pixels
- # def cut(data0, wcs0, scale=1):
- # shape = data0.shape[0]
- # x0 = shape / 2
- # size = 128 * scale
- # cutout = Cutout2D(data0, (x0, x0), (size, size), wcs=wcs0)
- # data, wcs = cutout.data, cutout.wcs
-
- # # Regrid data
- # factor = size // 128
- # data = data.reshape(128, factor, 128, factor).mean(-1).mean(1)
-
- # # Regrid wcs
- # ra, dec = wcs.wcs_pix2world(np.array([[63, 63]]),0)[0]
- # wcs.wcs.cdelt[0] = wcs.wcs.cdelt[0] * factor
- # wcs.wcs.cdelt[1] = wcs.wcs.cdelt[1] * factor
- # wcs.wcs.crval[0] = ra
- # wcs.wcs.crval[1] = dec
- # wcs.wcs.crpix[0] = 64 / factor
- # wcs.wcs.crpix[1] = 64 / factor
-
- # return data, wcs
-
- # # Define function to apply cutting and produce a prediction
- # # @st.cache
- # def cut_n_predict(data, wcs, scale):
- # data, wcs = cut(data, wcs, scale=scale)
- # image = np.log10(data+1)
-
- # y_pred = 0
- # for j in [0,1,2,3]:
- # rotated = np.rot90(image, j)
- # pred = model.predict(rotated.reshape(1, 128, 128, 1)).reshape(128 ,128)
- # pred = np.rot90(pred, -j)
- # y_pred += pred / 4
-
- # return y_pred, wcs
-
- # # Define function to decompose prediction into individual cavities
- # # @st.cache
- # def decompose_cavity(pred, th2=0.7, amin=6):
- # X, Y = pred.nonzero()
- # data = np.array([X,Y]).reshape(2, -1)
-
- # # DBSCAN clustering
- # try: clusters = DBSCAN(eps=1.0, min_samples=3).fit(data.T).labels_
- # except: clusters = []
-
- # N = len(set(clusters))
- # cavities = []
-
- # for i in range(N):
- # img = np.zeros((128,128))
- # b = clusters == i
- # xi, yi = X[b], Y[b]
- # img[xi, yi] = pred[xi, yi]
-
- # # # Thresholding #2
- # # if not (img > th2).any(): continue
-
- # # Minimal area
- # if np.sum(img) <= amin: continue
-
- # cavities.append(img)
-
- # # Save raw and decomposed predictions to predictions folder
- # ccd = CCDData(pred, unit="adu", wcs=wcs)
- # ccd.write(f"predictions/predicted.fits", overwrite=True)
- # image_decomposed = np.zeros((128,128))
- # for i, cav in enumerate(cavities):
- # ccd = CCDData(cav, unit="adu", wcs=wcs)
- # ccd.write(f"predictions/predicted_{i+1}.fits", overwrite=True)
- # image_decomposed += (i+1) * np.where(cav > 0, 1, 0)
-
- # # shutil.make_archive("predictions", 'zip', "predictions")
-
- # return image_decomposed
-
- # # @st.cache
- # def load_file(fname):
- # with fits.open(fname) as hdul:
- # data = hdul[0].data
- # wcs = WCS(hdul[0].header)
- # return data, wcs
-
- # def change_scale():
- # del st.session_state["threshold"]
-
- # # Use wide layout and create columns
- # st.set_page_config(page_title="Cavity Detection Tool", layout="wide")
- # bordersize = 0.45
- # _, col, _ = st.columns([bordersize, 3, bordersize])
-
  # os.system("mkdir -p predictions")
-
- # with col:
- # # Create heading and description
- # st.markdown("<h1 align='center'>Cavity Detection Tool</h1>", unsafe_allow_html=True)
- # st.markdown("Cavity Detection Tool (CADET) is a machine learning pipeline trained to detect X-ray cavities from noisy Chandra images of early-type galaxies.")
- # st.markdown("To use this tool: upload your image, select the scale of interest, make a prediction, and decompose it into individual cavities!")
- # st.markdown("Input images should be in units of counts, centred at the galaxy center, and point sources should be filled with surrounding background ([dmfilth](https://cxc.cfa.harvard.edu/ciao/ahelp/dmfilth.html)).")
- # st.markdown("If you use this tool for your research, please cite [Plšek et al. 2023](https://arxiv.org/abs/2304.05457)")
-
- # # _, col_1, col_2, col_3, _ = st.columns([bordersize, 2.0, 0.5, 0.5, bordersize])
-
- # # with col:
- # uploaded_file = st.file_uploader("Choose a FITS file", type=['fits'])
-
- # # with col_2:
- # # st.markdown("### Examples")
- # # NGC4649 = st.button("NGC4649")
-
- # # with col_3:
- # # st.markdown("""<style>[data-baseweb="select"] {margin-top: 26px;}</style>""", unsafe_allow_html=True)
- # # NGC5813 = st.button("NGC5813")
-
- # # if NGC4649:
- # # uploaded_file = "NGC4649_example.fits"
- # # elif NGC5813:
- # # uploaded_file = "NGC5813_example.fits"
-
- # # # If file is uploaded, read in the data and plot it
- # # if uploaded_file is not None:
- # # data, wcs = load_file(uploaded_file)
-
- # # # if "data" not in locals():
- # # # data = np.zeros((128,128))
-
- # # # Make six columns for buttons
- # # _, col1, col2, col3, col4, col5, col6, _ = st.columns([bordersize,0.5,0.5,0.5,0.5,0.5,0.5,bordersize])
- # # col1.subheader("Input image")
- # # col3.subheader("Prediction")
- # # col5.subheader("Decomposed")
- # # col6.subheader("")
-
- # # with col1:
- # # st.markdown("""<style>[data-baseweb="select"] {margin-top: -46px;}</style>""", unsafe_allow_html=True)
- # # max_scale = int(data.shape[0] // 128)
- # # scale = st.selectbox('Scale:',[f"{(i+1)*128}x{(i+1)*128}" for i in range(max_scale)], label_visibility="hidden", on_change=change_scale)
- # # scale = int(scale.split("x")[0]) // 128
-
- # # # Detect button
- # # with col3: detect = st.button('Detect', key="detect")
-
- # # # Threshold slider
- # # with col4:
- # # st.markdown("")
- # # # st.markdown("""<style>[data-baseweb="select"] {margin-top: -36px;}</style>""", unsafe_allow_html=True)
- # # threshold = st.slider("Threshold", 0.0, 1.0, 0.0, 0.05, key="threshold") #, label_visibility="hidden")
-
- # # # Decompose button
- # # with col5: decompose = st.button('Decompose', key="decompose")
-
- # # # Make two columns for plots
- # # _, colA, colB, colC, _ = st.columns([bordersize,1,1,1,bordersize])
-
- # # image = np.log10(data+1)
- # # plot_image(image, scale)
-
- # # if detect or threshold:
- # # # if st.session_state.get("detect", True):
- # # y_pred, wcs = cut_n_predict(data, wcs, scale)
-
- # # y_pred_th = np.where(y_pred > threshold, y_pred, 0)
-
- # # plot_prediction(y_pred_th)
-
- # # if decompose or st.session_state.get("download", False):
- # # image_decomposed = decompose_cavity(y_pred_th)
-
- # # plot_decomposed(image_decomposed)
-
- # # with col6:
- # # st.markdown("<br style='margin:4px 0'>", unsafe_allow_html=True)
- # # # st.markdown("""<style>[data-baseweb="select"] {margin-top: 16px;}</style>""", unsafe_allow_html=True)
- # # fname = uploaded_file.name.strip(".fits")
-
- # # # if st.session_state.get("download", False):
-
- # # shutil.make_archive("predictions", 'zip', "predictions")
- # # with open('predictions.zip', 'rb') as f:
- # # res = f.read()
-
- # # download = st.download_button(label="Download", data=res, key="download",
- # # file_name=f'{fname}_{int(scale*128)}.zip',
- # # # disabled=st.session_state.get("disabled", True),
- # # mime="application/octet-stream")
  # Basic libraries
  import os
  import shutil
  import numpy as np
  from scipy.ndimage import center_of_mass
  import matplotlib.pyplot as plt
+ from matplotlib.colors import Normalize
  from matplotlib.patches import Rectangle
 
  # Astropy
 
  # Streamlit
  import streamlit as st
  st.set_option('deprecation.showPyplotGlobalUse', False)
+ st.set_page_config(page_title="Cavity Detection Tool", layout="wide")
 
+ # HuggingFace Hub
+ # os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
+ from huggingface_hub import from_pretrained_keras
+ # from tensorflow.keras.models import load_model
+
+
+ # Define function to plot the uploaded image
+ def plot_image(image, scale):
+ plt.figure(figsize=(4, 4))
+ x0 = image.shape[0] // 2 - scale * 128 / 2
+ plt.imshow(image, origin="lower")
+ plt.gca().add_patch(Rectangle((x0-0.5, x0-0.5), scale*128, scale*128, linewidth=1, edgecolor='w', facecolor='none'))
+ plt.axis('off')
+ plt.tight_layout()
+ with colA: st.pyplot()
+
+ # Define function to plot the prediction
+ def plot_prediction(pred):
+ plt.figure(figsize=(4, 4))
+ plt.imshow(pred, origin="lower", norm=Normalize(vmin=0, vmax=1))
+ plt.axis('off')
+ with colB: st.pyplot()
+
+ # Define function to plot the decomposed prediction
+ def plot_decomposed(decomposed):
+ plt.figure(figsize=(4, 4))
+ plt.imshow(decomposed, origin="lower")
+
+ N = int(np.max(decomposed))
+ for i in range(N):
+ new = np.where(decomposed == i+1, 1, 0)
+ x0, y0 = center_of_mass(new)
+ color = "white" if i < N//2 else "black"
+ plt.text(y0, x0, f"{i+1}", ha="center", va="center", fontsize=15, color=color)
 
+ plt.axis('off')
+ with colC: st.pyplot()
 
+ # Define function to cut input image and rebin it to 128x128 pixels
+ def cut(data0, wcs0, scale=1):
+ shape = data0.shape[0]
+ x0 = shape / 2
+ size = 128 * scale
+ cutout = Cutout2D(data0, (x0, x0), (size, size), wcs=wcs0)
+ data, wcs = cutout.data, cutout.wcs
+
+ # Regrid data
+ factor = size // 128
+ data = data.reshape(128, factor, 128, factor).mean(-1).mean(1)
 
+ # Regrid wcs
+ ra, dec = wcs.wcs_pix2world(np.array([[63, 63]]),0)[0]
+ wcs.wcs.cdelt[0] = wcs.wcs.cdelt[0] * factor
+ wcs.wcs.cdelt[1] = wcs.wcs.cdelt[1] * factor
+ wcs.wcs.crval[0] = ra
+ wcs.wcs.crval[1] = dec
+ wcs.wcs.crpix[0] = 64 / factor
+ wcs.wcs.crpix[1] = 64 / factor
+
+ return data, wcs
+
+ # Define function to apply cutting and produce a prediction
+ @st.cache_data
+ def cut_n_predict(data, _wcs, scale):
+ data, wcs = cut(data, _wcs, scale=scale)
+ image = np.log10(data+1)
 
+ y_pred = 0
+ for j in [0,1,2,3]:
+ rotated = np.rot90(image, j)
+ pred = model.predict(rotated.reshape(1, 128, 128, 1)).reshape(128 ,128)
+ pred = np.rot90(pred, -j)
+ y_pred += pred / 4
+
+ return y_pred, wcs
+
+ # Define function to decompose prediction into individual cavities
+ @st.cache_data
+ def decompose_cavity(pred, fname, th2=0.7, amin=10):
+ X, Y = pred.nonzero()
+ data = np.array([X,Y]).reshape(2, -1)
+
+ # DBSCAN clustering
+ try: clusters = DBSCAN(eps=1.0, min_samples=3).fit(data.T).labels_
+ except: clusters = []
+
+ N = len(set(clusters))
+ cavities = []
+
+ for i in range(N):
+ img = np.zeros((128,128))
+ b = clusters == i
+ xi, yi = X[b], Y[b]
+ img[xi, yi] = pred[xi, yi]
+
+ # # Thresholding #2
+ # if not (img > th2).any(): continue
+
+ # Minimal area
+ if np.sum(img) <= amin: continue
+
+ cavities.append(img)
+
+ # Save raw and decomposed predictions to predictions folder
+ ccd = CCDData(pred, unit="adu", wcs=wcs)
+ ccd.write(f"{fname}/predicted.fits", overwrite=True)
+ image_decomposed = np.zeros((128,128))
+ for i, cav in enumerate(cavities):
+ ccd = CCDData(cav, unit="adu", wcs=wcs)
+ ccd.write(f"{fname}/decomposed_{i+1}.fits", overwrite=True)
+ image_decomposed += (i+1) * np.where(cav > 0, 1, 0)
+
+ # shutil.make_archive("predictions", 'zip', "predictions")
 
+ return image_decomposed
 
+ @st.cache_data
+ def load_file(fname):
+ with fits.open(fname) as hdul:
+ data = hdul[0].data
+ wcs = WCS(hdul[0].header)
+ return data, wcs
 
+ @st.cache_resource
+ def load_CADET():
+ model = from_pretrained_keras("Plsek/CADET-v1")
+ # model = load_model("CADET.hdf5")
+ return model
 
+ def reset_threshold():
+ del st.session_state["threshold"]
 
 
+ # Load model
+ model = load_CADET()
 
+ # Use wide layout and create columns
+ bordersize = 0.6
+ _, col, _ = st.columns([bordersize, 3, bordersize])
+
+ os.system("rm *.zip")
+ os.system("rm -R -- */")
+ # if os.path.exists("predictions"): os.system("rm -r predictions")
  # os.system("mkdir -p predictions")
 
+ with col:
+ # Create heading and description
+ st.markdown("<h1 align='center'>Cavity Detection Tool</h1>", unsafe_allow_html=True)
+ st.markdown("Cavity Detection Tool (CADET) is a machine learning pipeline trained to detect X-ray cavities from noisy Chandra images of early-type galaxies.")
+ st.markdown("To use this tool: upload your image, select the scale of interest, make a prediction, and decompose it into individual cavities!")
+ st.markdown("Input images should be in units of counts, centred at the galaxy center, and point sources should be filled with surrounding background ([dmfilth](https://cxc.cfa.harvard.edu/ciao/ahelp/dmfilth.html)).")
+ st.markdown("If you use this tool for your research, please cite [Plšek et al. 2023](https://arxiv.org/abs/2304.05457)")
 
+ # _, col_1, col_2, col_3, _ = st.columns([bordersize, 2.0, 0.5, 0.5, bordersize])
 
+ # with col:
+ uploaded_file = st.file_uploader("Choose a FITS file", type=['fits']) #, on_change=reset_threshold)
 
+ # with col_2:
+ # st.markdown("### Examples")
+ # NGC4649 = st.button("NGC4649")
 
+ # with col_3:
+ # st.markdown("""<style>[data-baseweb="select"] {margin-top: 26px;}</style>""", unsafe_allow_html=True)
+ # NGC5813 = st.button("NGC5813")
 
+ # if NGC4649:
+ # uploaded_file = "NGC4649_example.fits"
+ # elif NGC5813:
+ # uploaded_file = "NGC5813_example.fits"
 
+ # If file is uploaded, read in the data and plot it
+ if uploaded_file is not None:
+ data, wcs = load_file(uploaded_file)
+ os.mkdir(uploaded_file.name.strip(".fits"))
 
+ if "data" not in locals():
+ data = np.zeros((128,128))
 
+ # Make six columns for buttons
+ _, col1, col2, col3, col4, col5, col6, _ = st.columns([bordersize,0.5,0.5,0.5,0.5,0.5,0.5,bordersize])
+ col1.subheader("Input image")
+ col3.subheader("Prediction")
+ col5.subheader("Decomposed")
+ col6.subheader("")
+
+ with col1:
+ st.markdown("""<style>[data-baseweb="select"] {margin-top: -46px;}</style>""", unsafe_allow_html=True)
+ max_scale = int(data.shape[0] // 128)
+ scale = st.selectbox('Scale:',[f"{(i+1)*128}x{(i+1)*128}" for i in range(max_scale)], label_visibility="hidden", on_change=reset_threshold)
+ scale = int(scale.split("x")[0]) // 128
+
+ # Detect button
+ with col3: detect = st.button('Detect', key="detect")
+
+ # Threshold slider
+ with col4:
+ st.markdown("")
+ # st.markdown("""<style>[data-baseweb="select"] {margin-top: -36px;}</style>""", unsafe_allow_html=True)
+ threshold = st.slider("Threshold", 0.0, 1.0, 0.0, 0.05, key="threshold") #, label_visibility="hidden")
 
+ # Decompose button
+ with col5: decompose = st.button('Decompose', key="decompose")
 
+ # Make two columns for plots
+ _, colA, colB, colC, _ = st.columns([bordersize,1,1,1,bordersize])
 
+ if uploaded_file is not None:
+ image = np.log10(data+1)
+ plot_image(image, scale)
 
+ if detect or threshold or st.session_state.get("decompose", False):
+ fname = uploaded_file.name.strip(".fits")
 
+ y_pred, wcs = cut_n_predict(data, wcs, scale)
+
+ y_pred_th = np.where(y_pred > threshold, y_pred, 0)
+
+ plot_prediction(y_pred_th)
 
+ if decompose or st.session_state.get("download", False):
+ image_decomposed = decompose_cavity(y_pred_th, fname)
+
+ plot_decomposed(image_decomposed)
+
+ with col6:
+ st.markdown("<br style='margin:4px 0'>", unsafe_allow_html=True)
+ # st.markdown("""<style>[data-baseweb="select"] {margin-top: 16px;}</style>""", unsafe_allow_html=True)
+
+ # if st.session_state.get("download", False):
+
+ shutil.make_archive(fname, 'zip', fname)
+ with open(f"{fname}.zip", 'rb') as f:
+ res = f.read()
+
+ download = st.download_button(label="Download", data=res, key="download",
+ file_name=f'{fname}_{int(scale*128)}.zip',
+ # disabled=st.session_state.get("disabled", True),
+ mime="application/octet-stream")
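For reference, the prediction step this commit un-comments and caches can also be exercised outside the Streamlit app. Below is a minimal sketch of the same rotation-averaged inference performed by `cut_n_predict`; the standalone script, the `counts` placeholder array, and the 0.5 threshold are illustrative assumptions and not part of app.py, while the `from_pretrained_keras` call and the "Plsek/CADET-v1" checkpoint are the ones used above.

```python
# Minimal sketch (not part of this commit): rotation-averaged CADET inference,
# mirroring cut_n_predict from app.py. Assumes a 128x128 counts image in `counts`.
import numpy as np
from huggingface_hub import from_pretrained_keras

model = from_pretrained_keras("Plsek/CADET-v1")             # same checkpoint as app.py

counts = np.random.poisson(1.0, (128, 128)).astype(float)   # placeholder input image
image = np.log10(counts + 1)                                # log-scale the counts, as in app.py

# Average the prediction over the four 90-degree rotations of the input
y_pred = np.zeros((128, 128))
for j in range(4):
    rotated = np.rot90(image, j)
    pred = model.predict(rotated.reshape(1, 128, 128, 1)).reshape(128, 128)
    y_pred += np.rot90(pred, -j) / 4

# Keep only pixels above a chosen threshold (the app exposes this as a slider)
y_pred_th = np.where(y_pred > 0.5, y_pred, 0)
```

Averaging over the four rotations is the test-time augmentation the app relies on; thresholding the averaged map is what the Threshold slider does before the DBSCAN-based decomposition into individual cavities.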