Update app.py
Browse files
app.py
CHANGED
@@ -25,224 +25,224 @@ from sklearn.cluster import DBSCAN
|
|
25 |
import streamlit as st
|
26 |
st.set_option('deprecation.showPyplotGlobalUse', False)
|
27 |
|
28 |
-
# Define function to plot the uploaded image
|
29 |
-
def plot_image(image, scale):
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
# Define function to plot the prediction
|
39 |
-
def plot_prediction(pred):
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
# Define function to plot the decomposed prediction
|
46 |
-
def plot_decomposed(decomposed):
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
|
57 |
-
|
58 |
-
|
59 |
|
60 |
-
# Define function to cut input image and rebin it to 128x128 pixels
|
61 |
-
def cut(data0, wcs0, scale=1):
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
# Define function to apply cutting and produce a prediction
|
84 |
-
# @st.cache
|
85 |
-
def cut_n_predict(data, wcs, scale):
|
86 |
-
|
87 |
-
|
88 |
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
# Define function to decompose prediction into individual cavities
|
99 |
-
# @st.cache
|
100 |
-
def decompose_cavity(pred, th2=0.7, amin=6):
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
|
136 |
-
|
137 |
|
138 |
-
# @st.cache
|
139 |
-
def load_file(fname):
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
|
145 |
-
def change_scale():
|
146 |
-
|
147 |
|
148 |
|
149 |
|
150 |
-
# Use wide layout and create columns
|
151 |
-
st.set_page_config(page_title="Cavity Detection Tool", layout="wide")
|
152 |
-
bordersize = 0.45
|
153 |
-
_, col, _ = st.columns([bordersize, 3, bordersize])
|
154 |
|
155 |
-
os.system("mkdir -p predictions")
|
156 |
|
157 |
-
with col:
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
-
|
163 |
-
|
164 |
|
165 |
-
# _, col_1, col_2, col_3, _ = st.columns([bordersize, 2.0, 0.5, 0.5, bordersize])
|
166 |
|
167 |
-
# with col:
|
168 |
-
|
169 |
|
170 |
-
# with col_2:
|
171 |
-
# st.markdown("### Examples")
|
172 |
-
# NGC4649 = st.button("NGC4649")
|
173 |
|
174 |
-
# with col_3:
|
175 |
-
# st.markdown("""<style>[data-baseweb="select"] {margin-top: 26px;}</style>""", unsafe_allow_html=True)
|
176 |
-
# NGC5813 = st.button("NGC5813")
|
177 |
|
178 |
-
# if NGC4649:
|
179 |
-
# uploaded_file = "NGC4649_example.fits"
|
180 |
-
# elif NGC5813:
|
181 |
-
# uploaded_file = "NGC5813_example.fits"
|
182 |
|
183 |
-
# # If file is uploaded, read in the data and plot it
|
184 |
-
# if uploaded_file is not None:
|
185 |
-
# data, wcs = load_file(uploaded_file)
|
186 |
|
187 |
-
# # if "data" not in locals():
|
188 |
-
# # data = np.zeros((128,128))
|
189 |
|
190 |
-
# # Make six columns for buttons
|
191 |
-
# _, col1, col2, col3, col4, col5, col6, _ = st.columns([bordersize,0.5,0.5,0.5,0.5,0.5,0.5,bordersize])
|
192 |
-
# col1.subheader("Input image")
|
193 |
-
# col3.subheader("Prediction")
|
194 |
-
# col5.subheader("Decomposed")
|
195 |
-
# col6.subheader("")
|
196 |
|
197 |
-
# with col1:
|
198 |
-
# st.markdown("""<style>[data-baseweb="select"] {margin-top: -46px;}</style>""", unsafe_allow_html=True)
|
199 |
-
# max_scale = int(data.shape[0] // 128)
|
200 |
-
# scale = st.selectbox('Scale:',[f"{(i+1)*128}x{(i+1)*128}" for i in range(max_scale)], label_visibility="hidden", on_change=change_scale)
|
201 |
-
# scale = int(scale.split("x")[0]) // 128
|
202 |
|
203 |
-
# # Detect button
|
204 |
-
# with col3: detect = st.button('Detect', key="detect")
|
205 |
|
206 |
-
# # Threshold slider
|
207 |
-
# with col4:
|
208 |
-
# st.markdown("")
|
209 |
-
# # st.markdown("""<style>[data-baseweb="select"] {margin-top: -36px;}</style>""", unsafe_allow_html=True)
|
210 |
-
# threshold = st.slider("Threshold", 0.0, 1.0, 0.0, 0.05, key="threshold") #, label_visibility="hidden")
|
211 |
|
212 |
-
# # Decompose button
|
213 |
-
# with col5: decompose = st.button('Decompose', key="decompose")
|
214 |
|
215 |
-
# # Make two columns for plots
|
216 |
-
# _, colA, colB, colC, _ = st.columns([bordersize,1,1,1,bordersize])
|
217 |
|
218 |
-
# image = np.log10(data+1)
|
219 |
-
# plot_image(image, scale)
|
220 |
|
221 |
-
# if detect or threshold:
|
222 |
-
# # if st.session_state.get("detect", True):
|
223 |
-
# y_pred, wcs = cut_n_predict(data, wcs, scale)
|
224 |
|
225 |
-
# y_pred_th = np.where(y_pred > threshold, y_pred, 0)
|
226 |
|
227 |
-
# plot_prediction(y_pred_th)
|
228 |
|
229 |
-
# if decompose or st.session_state.get("download", False):
|
230 |
-
# image_decomposed = decompose_cavity(y_pred_th)
|
231 |
|
232 |
-
# plot_decomposed(image_decomposed)
|
233 |
|
234 |
-
# with col6:
|
235 |
-
# st.markdown("<br style='margin:4px 0'>", unsafe_allow_html=True)
|
236 |
-
# # st.markdown("""<style>[data-baseweb="select"] {margin-top: 16px;}</style>""", unsafe_allow_html=True)
|
237 |
-
# fname = uploaded_file.name.strip(".fits")
|
238 |
|
239 |
-
# # if st.session_state.get("download", False):
|
240 |
|
241 |
-
# shutil.make_archive("predictions", 'zip', "predictions")
|
242 |
-
# with open('predictions.zip', 'rb') as f:
|
243 |
-
# res = f.read()
|
244 |
|
245 |
-
# download = st.download_button(label="Download", data=res, key="download",
|
246 |
-
# file_name=f'{fname}_{int(scale*128)}.zip',
|
247 |
-
# # disabled=st.session_state.get("disabled", True),
|
248 |
-
# mime="application/octet-stream")
|
|
|
25 |
import streamlit as st
|
26 |
st.set_option('deprecation.showPyplotGlobalUse', False)
|
27 |
|
28 |
+
# # Define function to plot the uploaded image
|
29 |
+
# def plot_image(image, scale):
|
30 |
+
# plt.figure(figsize=(4, 4))
|
31 |
+
# x0 = image.shape[0] // 2 - scale * 128 / 2
|
32 |
+
# plt.imshow(image, origin="lower")
|
33 |
+
# plt.gca().add_patch(Rectangle((x0-0.5, x0-0.5), scale*128, scale*128, linewidth=1, edgecolor='w', facecolor='none'))
|
34 |
+
# plt.axis('off')
|
35 |
+
# plt.tight_layout()
|
36 |
+
# with colA: st.pyplot()
|
37 |
+
|
38 |
+
# # Define function to plot the prediction
|
39 |
+
# def plot_prediction(pred):
|
40 |
+
# plt.figure(figsize=(4, 4))
|
41 |
+
# plt.imshow(pred, origin="lower")
|
42 |
+
# plt.axis('off')
|
43 |
+
# with colB: st.pyplot()
|
44 |
+
|
45 |
+
# # Define function to plot the decomposed prediction
|
46 |
+
# def plot_decomposed(decomposed):
|
47 |
+
# plt.figure(figsize=(4, 4))
|
48 |
+
# plt.imshow(decomposed, origin="lower") #, norm=LogNorm())
|
49 |
+
|
50 |
+
# N = int(np.max(decomposed))
|
51 |
+
# for i in range(N):
|
52 |
+
# new = np.where(decomposed == i+1, 1, 0)
|
53 |
+
# x0, y0 = center_of_mass(new)
|
54 |
+
# color = "white" if i < N//2 else "black"
|
55 |
+
# plt.text(y0, x0, f"{i+1}", ha="center", va="center", fontsize=15, color=color)
|
56 |
|
57 |
+
# plt.axis('off')
|
58 |
+
# with colC: st.pyplot()
|
59 |
|
60 |
+
# # Define function to cut input image and rebin it to 128x128 pixels
|
61 |
+
# def cut(data0, wcs0, scale=1):
|
62 |
+
# shape = data0.shape[0]
|
63 |
+
# x0 = shape / 2
|
64 |
+
# size = 128 * scale
|
65 |
+
# cutout = Cutout2D(data0, (x0, x0), (size, size), wcs=wcs0)
|
66 |
+
# data, wcs = cutout.data, cutout.wcs
|
67 |
+
|
68 |
+
# # Regrid data
|
69 |
+
# factor = size // 128
|
70 |
+
# data = data.reshape(128, factor, 128, factor).mean(-1).mean(1)
|
71 |
|
72 |
+
# # Regrid wcs
|
73 |
+
# ra, dec = wcs.wcs_pix2world(np.array([[63, 63]]),0)[0]
|
74 |
+
# wcs.wcs.cdelt[0] = wcs.wcs.cdelt[0] * factor
|
75 |
+
# wcs.wcs.cdelt[1] = wcs.wcs.cdelt[1] * factor
|
76 |
+
# wcs.wcs.crval[0] = ra
|
77 |
+
# wcs.wcs.crval[1] = dec
|
78 |
+
# wcs.wcs.crpix[0] = 64 / factor
|
79 |
+
# wcs.wcs.crpix[1] = 64 / factor
|
80 |
+
|
81 |
+
# return data, wcs
|
82 |
+
|
83 |
+
# # Define function to apply cutting and produce a prediction
|
84 |
+
# # @st.cache
|
85 |
+
# def cut_n_predict(data, wcs, scale):
|
86 |
+
# data, wcs = cut(data, wcs, scale=scale)
|
87 |
+
# image = np.log10(data+1)
|
88 |
|
89 |
+
# y_pred = 0
|
90 |
+
# for j in [0,1,2,3]:
|
91 |
+
# rotated = np.rot90(image, j)
|
92 |
+
# pred = model.predict(rotated.reshape(1, 128, 128, 1)).reshape(128 ,128)
|
93 |
+
# pred = np.rot90(pred, -j)
|
94 |
+
# y_pred += pred / 4
|
95 |
+
|
96 |
+
# return y_pred, wcs
|
97 |
+
|
98 |
+
# # Define function to decompose prediction into individual cavities
|
99 |
+
# # @st.cache
|
100 |
+
# def decompose_cavity(pred, th2=0.7, amin=6):
|
101 |
+
# X, Y = pred.nonzero()
|
102 |
+
# data = np.array([X,Y]).reshape(2, -1)
|
103 |
+
|
104 |
+
# # DBSCAN clustering
|
105 |
+
# try: clusters = DBSCAN(eps=1.0, min_samples=3).fit(data.T).labels_
|
106 |
+
# except: clusters = []
|
107 |
+
|
108 |
+
# N = len(set(clusters))
|
109 |
+
# cavities = []
|
110 |
+
|
111 |
+
# for i in range(N):
|
112 |
+
# img = np.zeros((128,128))
|
113 |
+
# b = clusters == i
|
114 |
+
# xi, yi = X[b], Y[b]
|
115 |
+
# img[xi, yi] = pred[xi, yi]
|
116 |
+
|
117 |
+
# # # Thresholding #2
|
118 |
+
# # if not (img > th2).any(): continue
|
119 |
+
|
120 |
+
# # Minimal area
|
121 |
+
# if np.sum(img) <= amin: continue
|
122 |
+
|
123 |
+
# cavities.append(img)
|
124 |
+
|
125 |
+
# # Save raw and decomposed predictions to predictions folder
|
126 |
+
# ccd = CCDData(pred, unit="adu", wcs=wcs)
|
127 |
+
# ccd.write(f"predictions/predicted.fits", overwrite=True)
|
128 |
+
# image_decomposed = np.zeros((128,128))
|
129 |
+
# for i, cav in enumerate(cavities):
|
130 |
+
# ccd = CCDData(cav, unit="adu", wcs=wcs)
|
131 |
+
# ccd.write(f"predictions/predicted_{i+1}.fits", overwrite=True)
|
132 |
+
# image_decomposed += (i+1) * np.where(cav > 0, 1, 0)
|
133 |
+
|
134 |
+
# # shutil.make_archive("predictions", 'zip', "predictions")
|
135 |
|
136 |
+
# return image_decomposed
|
137 |
|
138 |
+
# # @st.cache
|
139 |
+
# def load_file(fname):
|
140 |
+
# with fits.open(fname) as hdul:
|
141 |
+
# data = hdul[0].data
|
142 |
+
# wcs = WCS(hdul[0].header)
|
143 |
+
# return data, wcs
|
144 |
|
145 |
+
# def change_scale():
|
146 |
+
# del st.session_state["threshold"]
|
147 |
|
148 |
|
149 |
|
150 |
+
# # Use wide layout and create columns
|
151 |
+
# st.set_page_config(page_title="Cavity Detection Tool", layout="wide")
|
152 |
+
# bordersize = 0.45
|
153 |
+
# _, col, _ = st.columns([bordersize, 3, bordersize])
|
154 |
|
155 |
+
# os.system("mkdir -p predictions")
|
156 |
|
157 |
+
# with col:
|
158 |
+
# # Create heading and description
|
159 |
+
# st.markdown("<h1 align='center'>Cavity Detection Tool</h1>", unsafe_allow_html=True)
|
160 |
+
# st.markdown("Cavity Detection Tool (CADET) is a machine learning pipeline trained to detect X-ray cavities from noisy Chandra images of early-type galaxies.")
|
161 |
+
# st.markdown("To use this tool: upload your image, select the scale of interest, make a prediction, and decompose it into individual cavities!")
|
162 |
+
# st.markdown("Input images should be in units of counts, centred at the galaxy center, and point sources should be filled with surrounding background ([dmfilth](https://cxc.cfa.harvard.edu/ciao/ahelp/dmfilth.html)).")
|
163 |
+
# st.markdown("If you use this tool for your research, please cite [Plšek et al. 2023](https://arxiv.org/abs/2304.05457)")
|
164 |
|
165 |
+
# # _, col_1, col_2, col_3, _ = st.columns([bordersize, 2.0, 0.5, 0.5, bordersize])
|
166 |
|
167 |
+
# # with col:
|
168 |
+
# uploaded_file = st.file_uploader("Choose a FITS file", type=['fits'])
|
169 |
|
170 |
+
# # with col_2:
|
171 |
+
# # st.markdown("### Examples")
|
172 |
+
# # NGC4649 = st.button("NGC4649")
|
173 |
|
174 |
+
# # with col_3:
|
175 |
+
# # st.markdown("""<style>[data-baseweb="select"] {margin-top: 26px;}</style>""", unsafe_allow_html=True)
|
176 |
+
# # NGC5813 = st.button("NGC5813")
|
177 |
|
178 |
+
# # if NGC4649:
|
179 |
+
# # uploaded_file = "NGC4649_example.fits"
|
180 |
+
# # elif NGC5813:
|
181 |
+
# # uploaded_file = "NGC5813_example.fits"
|
182 |
|
183 |
+
# # # If file is uploaded, read in the data and plot it
|
184 |
+
# # if uploaded_file is not None:
|
185 |
+
# # data, wcs = load_file(uploaded_file)
|
186 |
|
187 |
+
# # # if "data" not in locals():
|
188 |
+
# # # data = np.zeros((128,128))
|
189 |
|
190 |
+
# # # Make six columns for buttons
|
191 |
+
# # _, col1, col2, col3, col4, col5, col6, _ = st.columns([bordersize,0.5,0.5,0.5,0.5,0.5,0.5,bordersize])
|
192 |
+
# # col1.subheader("Input image")
|
193 |
+
# # col3.subheader("Prediction")
|
194 |
+
# # col5.subheader("Decomposed")
|
195 |
+
# # col6.subheader("")
|
196 |
|
197 |
+
# # with col1:
|
198 |
+
# # st.markdown("""<style>[data-baseweb="select"] {margin-top: -46px;}</style>""", unsafe_allow_html=True)
|
199 |
+
# # max_scale = int(data.shape[0] // 128)
|
200 |
+
# # scale = st.selectbox('Scale:',[f"{(i+1)*128}x{(i+1)*128}" for i in range(max_scale)], label_visibility="hidden", on_change=change_scale)
|
201 |
+
# # scale = int(scale.split("x")[0]) // 128
|
202 |
|
203 |
+
# # # Detect button
|
204 |
+
# # with col3: detect = st.button('Detect', key="detect")
|
205 |
|
206 |
+
# # # Threshold slider
|
207 |
+
# # with col4:
|
208 |
+
# # st.markdown("")
|
209 |
+
# # # st.markdown("""<style>[data-baseweb="select"] {margin-top: -36px;}</style>""", unsafe_allow_html=True)
|
210 |
+
# # threshold = st.slider("Threshold", 0.0, 1.0, 0.0, 0.05, key="threshold") #, label_visibility="hidden")
|
211 |
|
212 |
+
# # # Decompose button
|
213 |
+
# # with col5: decompose = st.button('Decompose', key="decompose")
|
214 |
|
215 |
+
# # # Make two columns for plots
|
216 |
+
# # _, colA, colB, colC, _ = st.columns([bordersize,1,1,1,bordersize])
|
217 |
|
218 |
+
# # image = np.log10(data+1)
|
219 |
+
# # plot_image(image, scale)
|
220 |
|
221 |
+
# # if detect or threshold:
|
222 |
+
# # # if st.session_state.get("detect", True):
|
223 |
+
# # y_pred, wcs = cut_n_predict(data, wcs, scale)
|
224 |
|
225 |
+
# # y_pred_th = np.where(y_pred > threshold, y_pred, 0)
|
226 |
|
227 |
+
# # plot_prediction(y_pred_th)
|
228 |
|
229 |
+
# # if decompose or st.session_state.get("download", False):
|
230 |
+
# # image_decomposed = decompose_cavity(y_pred_th)
|
231 |
|
232 |
+
# # plot_decomposed(image_decomposed)
|
233 |
|
234 |
+
# # with col6:
|
235 |
+
# # st.markdown("<br style='margin:4px 0'>", unsafe_allow_html=True)
|
236 |
+
# # # st.markdown("""<style>[data-baseweb="select"] {margin-top: 16px;}</style>""", unsafe_allow_html=True)
|
237 |
+
# # fname = uploaded_file.name.strip(".fits")
|
238 |
|
239 |
+
# # # if st.session_state.get("download", False):
|
240 |
|
241 |
+
# # shutil.make_archive("predictions", 'zip', "predictions")
|
242 |
+
# # with open('predictions.zip', 'rb') as f:
|
243 |
+
# # res = f.read()
|
244 |
|
245 |
+
# # download = st.download_button(label="Download", data=res, key="download",
|
246 |
+
# # file_name=f'{fname}_{int(scale*128)}.zip',
|
247 |
+
# # # disabled=st.session_state.get("disabled", True),
|
248 |
+
# # mime="application/octet-stream")
|