File size: 8,113 Bytes
0ea9e86
57f9dda
 
0ea9e86
7454012
e52c641
2bb6d79
6414f94
b4f54b4
6414f94
8111624
2680aa6
7454012
 
6414f94
 
3190548
d01a85f
 
7454012
1197a1f
 
 
7454012
 
6414f94
3ca195b
125c6bf
6414f94
 
d01a85f
ddce24a
d01a85f
 
2680aa6
3190548
32d9557
3190548
 
8120054
c1f2126
6414f94
 
32d9557
6414f94
29cd1e9
47ec853
29cd1e9
d0a6846
8111624
7cb9bcf
8111624
 
 
48b10cd
 
8111624
29cd1e9
 
 
7454012
6414f94
d26c581
 
 
 
 
 
7454012
d26c581
 
 
7454012
d26c581
 
 
 
 
 
 
 
44cb863
6414f94
a8b8828
5d6cf3b
9219e24
a8b8828
 
 
 
 
 
 
 
 
 
4161769
a8b8828
 
38ebf86
15c5050
 
 
 
 
7e4e8dd
15c5050
 
 
 
 
 
 
 
 
 
 
7c32f83
 
15c5050
7c32f83
 
15c5050
 
 
393d78f
 
 
 
 
 
 
 
 
e9ef2e7
bcc9b96
 
 
 
 
 
e9ef2e7
 
 
 
4514163
e9ef2e7
 
9d7e4e9
e9ef2e7
 
 
 
 
 
 
 
 
 
 
b635f79
6414f94
 
 
 
 
9c67ee7
8516575
36e4579
125c6bf
 
58a68f2
7e4e8dd
0e52fa4
36a5e8a
925e662
bb1bb6a
125c6bf
cc6eac9
bb17b88
0e52fa4
36a5e8a
58a68f2
8e1fb51
36a5e8a
3bee343
15c5050
bb17b88
a18eaf5
58a68f2
bb17b88
32d9557
7bcd4a9
a8b8828
 
7bcd4a9
cc6eac9
a8b8828
e3ab055
4161769
7bcd4a9
8e54123
7bcd4a9
9219e24
 
3bee343
9219e24
 
 
 
 
 
 
 
7bcd4a9
393d78f
9219e24
393d78f
 
 
 
 
 
 
 
 
 
 
 
60d64ba
 
 
 
7c32f83
60d64ba
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
# HuggingFace
from huggingface_hub import from_pretrained_keras
# Keras model used by cut_n_predict() below, loaded from the "Plsek/CADET-v1"
# repository. This runs at module level, i.e. every time the script executes.
# NOTE(review): Streamlit re-executes the script on each interaction, so this
# load is presumably repeated per rerun -- confirm and consider caching it.
model = from_pretrained_keras("Plsek/CADET-v1")

# Basic libraries
import os
import shutil
import numpy as np
from scipy.ndimage import center_of_mass
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
from matplotlib.patches import Rectangle

# Astropy
from astropy.io import fits
from astropy.wcs import WCS
from astropy.nddata import Cutout2D, CCDData
from astropy.convolution import Gaussian2DKernel as Gauss
from astropy.convolution import convolve

# Scikit-learn
from sklearn.cluster import DBSCAN

# Streamlit
import streamlit as st
# Silence the deprecation warning tied to calling st.pyplot() without an
# explicit figure (the option name refers to "global use" of pyplot).
st.set_option('deprecation.showPyplotGlobalUse', False)
# Browser tab title and wide page layout for the whole app.
st.set_page_config(page_title="Cavity Detection Tool", layout="wide")
# st.title("Cavity Detection Tool")

# Define function to plot the uploaded image
def plot_image(image, scale):
    """Show the input image with a square marking the region to be analysed.

    Parameters
    ----------
    image : 2-D array
        Image to display (the caller passes the log-scaled input data).
    scale : int
        Side length of the marked square in units of 128 pixels.

    The figure is rendered into the module-level Streamlit column ``colA``.
    """
    # Use an explicit figure instead of pyplot's implicit global one, so the
    # figure handed to Streamlit is unambiguous and can be closed afterwards.
    fig, ax = plt.subplots(figsize=(4, 4))

    # NOTE: both corner coordinates come from image.shape[0], i.e. the square
    # is centred only for square images. This deliberately mirrors cut(),
    # which derives its cutout centre the same way.
    side = scale * 128
    x0 = image.shape[0] // 2 - side / 2

    ax.imshow(image, origin="lower")
    ax.add_patch(Rectangle((x0, x0), side, side, linewidth=1, edgecolor='w', facecolor='none'))
    ax.axis('off')

    with colA:
        st.pyplot(fig)

    # Release the figure; otherwise pyplot keeps every figure alive across reruns.
    plt.close(fig)

# Define function to plot the prediction
def plot_prediction(pred):
    """Show the (thresholded) prediction map.

    Parameters
    ----------
    pred : 2-D array
        Prediction image to display.

    The figure is rendered into the module-level Streamlit column ``colB``.
    """
    # Explicit figure rather than pyplot's implicit global one.
    fig, ax = plt.subplots(figsize=(4, 4))
    ax.imshow(pred, origin="lower")
    ax.axis('off')

    with colB:
        st.pyplot(fig)

    # Release the figure so figures do not accumulate across reruns.
    plt.close(fig)

# Define function to plot the decomposed prediction
def plot_decomposed(decomposed):
    """Show the labelled cavity map with each cavity's number drawn on it.

    Parameters
    ----------
    decomposed : 2-D array
        Label image in which pixels of the k-th cavity hold the value ``k``
        (1-based) and all other pixels are 0.

    The figure is rendered into the module-level Streamlit column ``colC``.
    """
    # Explicit figure rather than pyplot's implicit global one.
    fig, ax = plt.subplots(figsize=(4, 4))
    ax.imshow(decomposed, origin="lower")

    n_labels = int(np.max(decomposed))
    for label in range(1, n_labels + 1):
        mask = np.where(decomposed == label, 1, 0)
        # center_of_mass returns (row, column); Matplotlib text wants (x, y),
        # i.e. (column, row).
        row, column = center_of_mass(mask)
        # Lower labels are drawn in white, upper ones in black (same split as
        # before: zero-based index < n_labels // 2 -> white).
        color = "white" if label - 1 < n_labels // 2 else "black"
        ax.text(column, row, f"{label}", ha="center", va="center", fontsize=15, color=color)

    ax.axis('off')

    with colC:
        st.pyplot(fig)

    # Release the figure so figures do not accumulate across reruns.
    plt.close(fig)
        
# Cut input image and rebin it to 128x128 pixels
def cut(data0, wcs0, scale=1):
    """Crop a central square from an image and rebin it to 128x128 pixels.

    Parameters
    ----------
    data0 : 2-D array
        Input image.
    wcs0 : astropy.wcs.WCS
        WCS of ``data0``; passed to ``Cutout2D`` so the cutout carries a WCS.
    scale : int, optional
        Side length of the cutout in units of 128 pixels (default 1).

    Returns
    -------
    data : 2-D array
        The cutout averaged over ``scale x scale`` pixel blocks, shape (128, 128).
    wcs : astropy.wcs.WCS
        WCS of the cutout with ``cdelt``, ``crval`` and ``crpix`` adjusted for
        the rebinning.

    Raises
    ------
    ValueError
        If the cutout is smaller than requested (it did not fit inside the
        image) and therefore cannot be rebinned.
    """
    # NOTE: both centre coordinates come from data0.shape[0], so the cutout is
    # centred only for square images (plot_image marks the same region).
    shape = data0.shape[0]
    x0 = shape / 2
    size = 128 * scale
    cutout = Cutout2D(data0, (x0, x0), (size, size), wcs=wcs0)
    data, wcs = cutout.data, cutout.wcs

    if data.shape != (size, size):
        # Previously this surfaced as an opaque "cannot reshape" ValueError.
        raise ValueError(
            f"cutout has shape {data.shape}, expected {(size, size)}; "
            "the requested scale does not fit inside the image"
        )

    # Regrid data: average over factor x factor blocks
    factor = size // 128
    data = data.reshape(128, factor, 128, factor).mean(-1).mean(1)

    # Regrid wcs: anchor the new WCS on cutout pixel (63, 63) (0-based), which
    # is FITS pixel 64 (1-based).
    ra, dec = wcs.wcs_pix2world(np.array([[63, 63]]), 0)[0]
    wcs.wcs.cdelt[0] = wcs.wcs.cdelt[0] * factor
    wcs.wcs.cdelt[1] = wcs.wcs.cdelt[1] * factor
    wcs.wcs.crval[0] = ra
    wcs.wcs.crval[1] = dec

    # Position of FITS pixel 64 in the rebinned grid. Pixel centres sit at
    # integer coordinates, so the mapping is (p - 0.5) / factor + 0.5 rather
    # than p / factor; the two agree only for factor == 1.
    crpix = (64 - 0.5) / factor + 0.5
    wcs.wcs.crpix[0] = crpix
    wcs.wcs.crpix[1] = crpix

    return data, wcs


@st.cache
def cut_n_predict(data, wcs, scale):
    """Crop/rebin the image and return the rotation-averaged model prediction.

    The image is cut with ``cut``, log-scaled as ``log10(data + 1)`` and fed
    to the module-level ``model`` in four orientations (0 to 3 quarter turns).
    Each prediction is rotated back and the four results are averaged.

    Returns the averaged 128x128 prediction and the WCS of the cutout.
    """
    cropped, cropped_wcs = cut(data, wcs, scale=scale)
    log_image = np.log10(cropped + 1)

    # One quarter-weighted contribution per orientation.
    contributions = [
        np.rot90(
            model.predict(
                np.rot90(log_image, turns).reshape(1, 128, 128, 1)
            ).reshape(128, 128),
            -turns,
        ) / 4
        for turns in range(4)
    ]

    # sum() starts from 0 and adds left to right, like the accumulation loop.
    y_pred = sum(contributions)

    return y_pred, cropped_wcs

# Define function to decompose prediction into cavities
@st.cache
def decompose_cavity(pred, th2=0.7, amin=10):
    """Split a prediction map into individual cavities and save them as FITS.

    Non-zero pixels of ``pred`` are clustered with DBSCAN. A cluster is kept
    as a cavity when at least one of its pixels exceeds ``th2`` and the sum of
    its pixel values exceeds ``amin``.

    Parameters
    ----------
    pred : 2-D array
        Prediction map; zero pixels are ignored.
    th2 : float, optional
        Value at least one pixel of a cluster must exceed (default 0.7).
    amin : float, optional
        Value the summed pixel values of a cluster must exceed (default 10).

    Returns
    -------
    2-D array
        Label image with the shape of ``pred``: pixels of the k-th kept
        cavity hold ``k`` (1-based), everything else is 0.

    Side effects
    ------------
    Writes ``predictions/predicted.fits`` (the full map) and one
    ``predictions/predicted_<k>.fits`` per cavity, using the module-level
    ``wcs``. Per-cavity files left over from earlier calls are removed first.
    """
    import glob

    X, Y = pred.nonzero()
    data = np.array([X, Y]).reshape(2, -1)

    # DBSCAN CLUSTERING ALGORITHM
    # DBSCAN rejects an empty sample set, so handle "nothing predicted"
    # explicitly instead of swallowing every exception.
    if data.shape[1] == 0:
        clusters = np.array([], dtype=int)
    else:
        clusters = DBSCAN(eps=1.0, min_samples=3).fit(data.T).labels_

    # Real cluster labels in ascending order; DBSCAN marks noise with -1.
    labels = [label for label in np.unique(clusters) if label >= 0]

    cavities = []
    for label in labels:
        img = np.zeros(pred.shape)
        member = clusters == label
        xi, yi = X[member], Y[member]
        img[xi, yi] = pred[xi, yi]

        # THRESHOLDING #2
        if not (img > th2).any():
            continue

        # MINIMAL AREA
        # NOTE(review): this compares the summed prediction values, not the
        # pixel count, against amin -- confirm that is the intended "area".
        if np.sum(img) <= amin:
            continue

        cavities.append(img)

    os.makedirs("predictions", exist_ok=True)

    # Drop per-cavity files from earlier runs; otherwise a run that finds
    # fewer cavities leaves stale files behind for the download archive.
    for stale in glob.glob("predictions/predicted_*.fits"):
        os.remove(stale)

    ccd = CCDData(pred, unit="adu", wcs=wcs)
    ccd.write("predictions/predicted.fits", overwrite=True)

    image_decomposed = np.zeros(pred.shape)
    for i, cav in enumerate(cavities):
        ccd = CCDData(cav, unit="adu", wcs=wcs)
        ccd.write(f"predictions/predicted_{i+1}.fits", overwrite=True)
        image_decomposed += (i+1) * np.where(cav > 0, 1, 0)

    return image_decomposed

# @st.cache
# def zip_predictions():
#     shutil.make_archive("predictions.zip", 'zip', "predictions")
#     with open('predictions.zip', 'rb') as f:
#         res = f.read()
#         return res

# ---------------------------------------------------------------------------
# Page layout and interaction flow. Streamlit runs this top to bottom on
# every user interaction.
# ---------------------------------------------------------------------------

# Relative width of the empty margins on both sides of the page.
bordersize = 0.6
_, col, _ = st.columns([bordersize, 3, bordersize])

# Directory the FITS outputs are written to and the download archive is built
# from. os.makedirs avoids spawning a shell and also works without `mkdir -p`.
os.makedirs("predictions", exist_ok=True)

with col:
    st.markdown("# Cavity Detection Tool")

    st.markdown("Cavity Detection Tool (CADET) is a machine learning pipeline trained to detect X-ray cavities from noisy Chandra images of early-type galaxies.")
    st.markdown("To use this tool: upload your image, select the scale of interest, and make a prediction!")
    st.markdown("Input images should be centred at the centre of the galaxy and point sources should be filled with surrounding background ([dmfilth](https://cxc.cfa.harvard.edu/ciao/ahelp/dmfilth.html)).")
    st.markdown("If you use this tool for your research, please cite [Plšek et al. 2023](https://arxiv.org/abs/2304.05457)")

    # Create file uploader widget
    uploaded_file = st.file_uploader("Choose a FITS file", type=['fits'])

# If file is uploaded, read in the data and plot it
if uploaded_file is not None:
    with fits.open(uploaded_file) as hdul:
        data = hdul[0].data
        wcs = WCS(hdul[0].header)

    # Reject inputs the rest of the pipeline cannot handle, instead of
    # crashing further down on `.shape` or on an empty scale list.
    if data is None or data.ndim != 2:
        with col:
            st.error("The primary HDU of the uploaded file does not contain a 2D image.")
        st.stop()

    max_scale = int(data.shape[0] // 128)
    if max_scale < 1:
        with col:
            st.error("The uploaded image is too small: at least 128x128 pixels are required.")
        st.stop()

    # Make six columns for buttons
    _, col1, col2, col3, col4, col5, col6, _ = st.columns([bordersize,0.5,0.5,0.5,0.5,0.5,0.5,bordersize])
    col1.subheader("Input image")
    col3.subheader("Prediction")
    col5.subheader("Decomposed")
    col6.subheader("")

    with col1:
        # Offer every multiple of 128 pixels that fits into the image.
        scale = st.selectbox('Scale:',[f"{(i+1)*128}x{(i+1)*128}" for i in range(max_scale)], label_visibility="hidden")
        scale = int(scale.split("x")[0]) // 128

    with col3:
        detect = st.button('Detect')

    with col5:
        decompose = st.button('Decompose')

    # Make three columns for plots (used by the plot_* functions)
    _, colA, colB, colC, _ = st.columns([bordersize,1,1,1,bordersize])

    image = np.log10(data+1)
    plot_image(image, scale)

    with col4:
        # Pull the widget up so it lines up with the buttons next to it.
        st.markdown("""<style>[data-baseweb="select"] {margin-top: -36px;}</style>""", unsafe_allow_html=True)
        threshold = st.slider("", 0.0, 1.0, 0.0, 0.05, label_visibility="hidden")

    # A button is True only during the rerun triggered by its own click, so
    # `decompose` must be part of this condition: otherwise clicking
    # "Decompose" with the threshold at 0.0 would show nothing at all.
    # cut_n_predict is cached, so re-entering here does not rerun the model.
    if detect or decompose or threshold:
        y_pred, wcs = cut_n_predict(data, wcs, scale)

        # Zero out everything at or below the selected threshold.
        y_pred_th = np.where(y_pred > threshold, y_pred, 0)
        plot_prediction(y_pred_th)

        if decompose:
            image_decomposed = decompose_cavity(y_pred_th)
            plot_decomposed(image_decomposed)

    with col6:
        # Bundle whatever is currently in predictions/ for download.
        shutil.make_archive("predictions", 'zip', "predictions")
        with open('predictions.zip', 'rb') as f:
            res = f.read()
        st.markdown("")
        st.download_button(label="Download", data=res, file_name='prediction.zip', mime="application/octet-stream")