File size: 7,470 Bytes
0ea9e86
57f9dda
 
0ea9e86
7454012
e52c641
8e54123
6414f94
b4f54b4
6414f94
8111624
2680aa6
7454012
 
6414f94
 
3190548
d01a85f
 
7454012
1197a1f
 
 
7454012
 
6414f94
3ca195b
125c6bf
6414f94
 
d01a85f
ddce24a
d01a85f
 
2680aa6
3190548
32d9557
3190548
 
8120054
c1f2126
6414f94
 
32d9557
6414f94
29cd1e9
47ec853
29cd1e9
d0a6846
8111624
7cb9bcf
8111624
 
 
48b10cd
 
8111624
29cd1e9
 
 
7454012
6414f94
d26c581
 
 
 
 
 
7454012
d26c581
 
 
7454012
d26c581
 
 
 
 
 
 
 
44cb863
6414f94
15c5050
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
983a306
 
15c5050
983a306
 
15c5050
 
 
 
e9ef2e7
 
 
 
 
 
 
4514163
e9ef2e7
 
 
 
 
 
 
 
 
 
 
 
 
 
b635f79
6414f94
 
 
 
 
9c67ee7
125c6bf
36e4579
125c6bf
 
58a68f2
0e52fa4
36a5e8a
bb1bb6a
125c6bf
bb17b88
 
0e52fa4
36a5e8a
58a68f2
8e1fb51
36a5e8a
58a68f2
15c5050
bb17b88
a18eaf5
58a68f2
bb17b88
32d9557
00d6ae6
bb1bb6a
00d6ae6
 
0e52fa4
8e54123
3138cd3
 
e5f809c
3138cd3
 
 
 
 
 
 
8e54123
4c80278
8e54123
983a306
8e54123
 
7d21ed6
e9ef2e7
4c80278
7d21ed6
d488b3c
905881e
e9ef2e7
15c5050
7d21ed6
905881e
c5186ae
e9ef2e7
0658391
905881e
e9ef2e7
 
e52c641
29cd1e9
47ec853
e9ef2e7
0658391
e9ef2e7
 
 
0658391
 
29cd1e9
e9ef2e7
48b10cd
7d21ed6
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
# HuggingFace
from huggingface_hub import from_pretrained_keras
# Load the pretrained CADET network from the HuggingFace Hub at import time.
# NOTE(review): Streamlit re-executes the script on every interaction, so this
# line presumably runs on each rerun -- consider caching the loaded model.
model = from_pretrained_keras("Plsek/CADET-v1")

# Basic libraries
import os
# import shutil
import numpy as np
from scipy.ndimage import center_of_mass
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
from matplotlib.patches import Rectangle

# Astropy
from astropy.io import fits
from astropy.wcs import WCS
from astropy.nddata import Cutout2D, CCDData
from astropy.convolution import Gaussian2DKernel as Gauss
from astropy.convolution import convolve

# Scikit-learn
from sklearn.cluster import DBSCAN

# Streamlit
import streamlit as st
# Turn off the "pyplot global use" deprecation notice (shown when st.pyplot()
# is called without an explicit figure).
st.set_option('deprecation.showPyplotGlobalUse', False)
# Wide layout; the page content is centred further down using padded columns.
st.set_page_config(page_title="Cavity Detection Tool", layout="wide")
# st.title("Cavity Detection Tool")

# Define function to plot the uploaded image
def plot_image(image, scale):
    """Show the input image in column ``colA`` with the cut-out region outlined.

    Parameters
    ----------
    image : 2D array
        Image to display.
    scale : int
        Size of the region of interest in units of 128 pixels; a white
        square with ``scale * 128`` pixels per side is drawn around the
        image centre.
    """
    side = scale * 128
    # Matplotlib's x axis runs along array axis 1 and y along axis 0, so the
    # two offsets must be computed separately for non-square images.
    x0 = image.shape[1] // 2 - side / 2
    y0 = image.shape[0] // 2 - side / 2

    # Use an explicit figure instead of pyplot's implicit global one.
    fig, ax = plt.subplots(figsize=(4, 4))
    ax.imshow(image, origin="lower")
    ax.add_patch(Rectangle((x0, y0), side, side, linewidth=1, edgecolor='w', facecolor='none'))
    ax.axis('off')
    with colA:
        st.pyplot(fig)
    # Close the figure so repeated script reruns do not accumulate open figures.
    plt.close(fig)

# Define function to plot the prediction
def plot_prediction(pred):
    """Show the prediction map ``pred`` (2D array) in column ``colB``."""
    # Use an explicit figure instead of pyplot's implicit global one.
    fig, ax = plt.subplots(figsize=(4, 4))
    ax.imshow(pred, origin="lower")
    ax.axis('off')
    with colB:
        st.pyplot(fig)
    # Close the figure so repeated script reruns do not accumulate open figures.
    plt.close(fig)

# Define function to plot the decomposed prediction
def plot_decomposed(decomposed):
    """Show the labelled cavity map in column ``colC`` and number each cavity.

    Parameters
    ----------
    decomposed : 2D array
        Label image: 0 is background and the pixels of the k-th cavity hold
        the value k. Each label 1..max is annotated at its centre of mass.
    """
    # Use an explicit figure instead of pyplot's implicit global one.
    fig, ax = plt.subplots(figsize=(4, 4))
    ax.imshow(decomposed, origin="lower")  # , norm=LogNorm())

    n_labels = int(np.max(decomposed))
    for label in range(1, n_labels + 1):
        mask = decomposed == label
        if not mask.any():
            # A label without pixels has no centre of mass (it would be NaN).
            continue
        # center_of_mass returns (row, column); text() expects (x, y),
        # i.e. (column, row).
        row_c, col_c = center_of_mass(mask.astype(int))
        # Lower labels get white text, upper labels black text.
        color = "white" if label - 1 < n_labels // 2 else "black"
        ax.text(col_c, row_c, f"{label}", ha="center", va="center", fontsize=15, color=color)

    ax.axis('off')
    with colC:
        st.pyplot(fig)
    # Close the figure so repeated script reruns do not accumulate open figures.
    plt.close(fig)
        
# Cut input image and rebin it to 128x128 pixels
def cut(data0, wcs0, scale=1):
    """Cut the central region out of an image and rebin it to 128x128 pixels.

    Parameters
    ----------
    data0 : 2D array
        Input image.
    wcs0 : astropy.wcs.WCS
        WCS belonging to ``data0``.
    scale : int, optional
        Size of the cut-out in units of 128 pixels: a central square of
        ``128 * scale`` pixels per side is extracted and block-averaged by
        ``scale`` along both axes.

    Returns
    -------
    data : 2D array
        The rebinned 128x128 image.
    wcs : astropy.wcs.WCS
        WCS of the cut-out with ``cdelt``, ``crval`` and ``crpix`` adjusted
        for the rebinning.

    Raises
    ------
    ValueError
        If the image is too small to provide the full cut-out.
    """
    size = 128 * scale
    # Cutout2D takes the position as (x, y); x runs along array axis 1 and
    # y along axis 0, which matters for non-square images.
    centre = (data0.shape[1] / 2, data0.shape[0] / 2)
    cutout = Cutout2D(data0, centre, (size, size), wcs=wcs0)
    data, wcs = cutout.data, cutout.wcs

    if data.shape != (size, size):
        # Fail with a clear message rather than a cryptic reshape error.
        raise ValueError(
            f"cannot cut a {size}x{size} region out of an image of shape {data0.shape}"
        )

    # Regrid data: average over factor x factor blocks of pixels
    factor = size // 128
    data = data.reshape(128, factor, 128, factor).mean(-1).mean(1)

    # Regrid wcs: anchor the reference point on 0-based pixel (ref, ref) of
    # the un-binned cut-out.
    # NOTE(review): only CDELT is rescaled; a header that stores the pixel
    # scale in a CD matrix would presumably need that matrix scaled instead.
    ref = 63
    ra, dec = wcs.wcs_pix2world(np.array([[ref, ref]]), 0)[0]
    wcs.wcs.cdelt[0] = wcs.wcs.cdelt[0] * factor
    wcs.wcs.cdelt[1] = wcs.wcs.cdelt[1] * factor
    wcs.wcs.crval[0] = ra
    wcs.wcs.crval[1] = dec
    # Block-averaging maps 1-based pixel p to (p - 0.5) / factor + 0.5, so the
    # anchor (1-based ref + 1) lands on (ref + 0.5) / factor + 0.5. This is
    # 64 for factor 1; the previous 64 / factor was slightly off otherwise.
    new_crpix = (ref + 0.5) / factor + 0.5
    wcs.wcs.crpix[0] = new_crpix
    wcs.wcs.crpix[1] = new_crpix

    return data, wcs

def decompose_cavity(pred, th2=0.7, amin=10):
    """Split a prediction map into individual cavities with DBSCAN.

    The coordinates of all non-zero pixels of ``pred`` are clustered
    (``eps=1.5``, ``min_samples=3``); every cluster becomes one image that
    holds the values of ``pred`` on the cluster pixels and zeros elsewhere.
    Pixels that DBSCAN labels as noise (label -1) belong to no cavity.

    Parameters
    ----------
    pred : 2D array
        Prediction map; zero-valued pixels are treated as background.
    th2 : float, optional
        Currently unused (second-threshold filter is disabled).
    amin : float, optional
        Currently unused (minimal-area filter is disabled).

    Returns
    -------
    list of 2D arrays
        One image of the same shape as ``pred`` per cavity; empty if ``pred``
        has no non-zero pixels or every pixel is classified as noise.
    """
    rows, cols = pred.nonzero()
    if rows.size == 0:
        # Nothing to cluster (DBSCAN rejects an empty sample set).
        return []

    # DBSCAN CLUSTERING ALGORITHM
    points = np.column_stack((rows, cols))
    labels = DBSCAN(eps=1.5, min_samples=3).fit(points).labels_

    # Cluster labels run from 0 upwards and noise is -1, so counting from
    # the maximum keeps the noise label out of the cavity list.
    n_clusters = int(labels.max()) + 1

    cavities = []
    for label in range(n_clusters):
        member = labels == label
        xi, yi = rows[member], cols[member]
        img = np.zeros(pred.shape)
        img[xi, yi] = pred[xi, yi]

        # # THRESHOLDING #2
        # if not (img > th2).any(): continue

        # # MINIMAL AREA
        # if np.sum(img) <= amin: continue

        cavities.append(img)

    return cavities




# Relative width of the empty padding columns on both sides of the page
bordersize = 0.6
_, col, _ = st.columns([bordersize, 3, bordersize])

# if os.path.exists("pred.npy"): os.system("rm pred.npy")
# os.system("rm -r predictions")
# os.system("rm predictions.zip Views")
# os.system("mkdir -p predictions")

# Title followed by the introductory paragraphs, in display order
intro_paragraphs = (
    "# Cavity Detection Tool",
    "Cavity Detection Tool (CADET) is a machine learning pipeline trained to detect X-ray cavities from noisy Chandra images of early-type galaxies.",
    "To use this tool: upload your image, select the scale of interest, and make a prediction!",
    "Input images should be centred at the centre of the galaxy and point sources should be filled with surrounding background ([dmfilth](https://cxc.cfa.harvard.edu/ciao/ahelp/dmfilth.html)).",
    "If you use this tool for your research, please cite [Plšek et al. 2023](https://arxiv.org/abs/2304.05457)",
)

with col:
    for paragraph in intro_paragraphs:
        st.markdown(paragraph)

    # File uploader widget (FITS files only)
    uploaded_file = st.file_uploader("Choose a FITS file", type=['fits'])
    
# If file is uploaded, read in the data and plot it
if uploaded_file is not None:
    with fits.open(uploaded_file) as hdul:
        data = hdul[0].data
        wcs = WCS(hdul[0].header)

    # The smallest selectable region is 128x128 pixels; anything smaller (or
    # a primary HDU without a 2D image) cannot be processed, so stop this run
    # with a message instead of crashing on an empty scale list.
    if data is None or data.ndim != 2 or min(data.shape) < 128:
        st.error("The primary HDU of the FITS file must contain a 2D image of at least 128x128 pixels.")
        st.stop()

    # Make six narrow columns for the sub-headers and controls
    _, col1, col2, col3, col4, col5, col6, _ = st.columns([bordersize,0.5,0.5,0.5,0.5,0.5,0.5,bordersize])
    col1.subheader("Input image")
    col3.subheader("Prediction")
    col5.subheader("Decomposed")

    with col1:
        # st.markdown("""<style>[data-baseweb="select"] {margin-top: -26px;}</style>""", unsafe_allow_html=True)
        # Offer every multiple of 128 pixels that fits inside the image along
        # both axes.
        max_scale = int(min(data.shape) // 128)
        scale = st.selectbox('Scale:',[f"{(i+1)*128}x{(i+1)*128}" for i in range(max_scale)], label_visibility="hidden")
        # Convert the selected "NxN" label back to a multiple of 128
        scale = int(scale.split("x")[0]) // 128

    with col3:
        detect = st.button('Detect')

    with col5:
        decompose = st.button('Decompose')

    # Make three columns for plots (used by the plot_* functions)
    _, colA, colB, colC, _ = st.columns([bordersize,1,1,1,bordersize])

    image = np.log10(data+1)
    plot_image(image, scale)

    with col4:
        st.markdown("""<style>[data-baseweb="select"] {margin-top: -36px;}</style>""", unsafe_allow_html=True)
        # The label stays hidden, but Streamlit asks for a non-empty one.
        threshold = st.slider("Threshold", 0.0, 1.0, 0.0, 0.05, label_visibility="hidden")

    # Nothing is kept between script reruns, so the prediction is recomputed
    # whenever it is needed: on "Detect", on "Decompose" (which would
    # otherwise work on an empty map) and while a non-zero threshold is set.
    if detect or decompose or threshold:
        data, wcs = cut(data, wcs, scale=scale)
        image = np.log10(data+1)

        # Average the predictions over the four 90-degree rotations of the
        # input, each rotated back to the original orientation.
        y_pred = 0
        for j in range(4):
            rotated = np.rot90(image, j)
            pred = model.predict(rotated.reshape(1, 128, 128, 1)).reshape(128, 128)
            y_pred += np.rot90(pred, -j) / 4
    else:
        # No prediction requested yet: show an empty map
        y_pred = np.zeros((128,128))

    # Keep only the pixels above the selected threshold
    y_pred_th = np.where(y_pred > threshold, y_pred, 0)

    plot_prediction(y_pred_th)

    # Label image: pixels of the k-th cavity are set to k, background stays 0
    image_decomposed = np.zeros((128,128))
    if decompose:
        for label, cav in enumerate(decompose_cavity(y_pred_th), start=1):
            image_decomposed += label * np.where(cav > 0, 1, 0)

    plot_decomposed(image_decomposed)

    # TODO: re-enable the export of the prediction (and of the individual
    # cavities) as FITS files together with the download button below.
    # with col6:
    #     ccd = CCDData(y_pred, unit="adu", wcs=wcs)
    #     # with open('predictions.zip', 'rb') as f:
    #     #     res = f.read()
    #     st.markdown("""<style>[data-baseweb="select"] {margin-top: 16px;}</style>""", unsafe_allow_html=True)
    #     download = st.download_button(label="Download", data=ccd, file_name='prediction.fits', mime="application/octet-stream")