File size: 6,783 Bytes
0ea9e86
57f9dda
 
0ea9e86
7454012
e52c641
 
6414f94
 
2680aa6
7454012
 
6414f94
 
3190548
d01a85f
 
7454012
1197a1f
 
 
7454012
 
6414f94
3ca195b
125c6bf
6414f94
574acaf
a18eaf5
f27ae54
e52c641
 
9bf2a3b
125c6bf
 
e52c641
 
 
9bf2a3b
 
 
a24616f
c1f2126
6414f94
d01a85f
ddce24a
c1f2126
d01a85f
 
2680aa6
c1f2126
3190548
32d9557
3190548
 
8120054
c1f2126
6414f94
 
32d9557
6414f94
29cd1e9
91f080b
29cd1e9
 
 
 
 
7454012
6414f94
d26c581
 
 
 
 
 
7454012
d26c581
 
 
7454012
d26c581
 
 
 
 
 
 
 
44cb863
6414f94
15c5050
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b635f79
6414f94
 
 
 
 
e04bf99
9c67ee7
125c6bf
36e4579
125c6bf
 
58a68f2
0e52fa4
36a5e8a
bb1bb6a
125c6bf
bb17b88
 
d559933
0e52fa4
36a5e8a
58a68f2
8e1fb51
36a5e8a
58a68f2
15c5050
bb17b88
a18eaf5
58a68f2
bb17b88
32d9557
00d6ae6
bb1bb6a
00d6ae6
 
0e52fa4
bb1bb6a
3138cd3
 
e5f809c
3138cd3
 
 
 
 
 
 
fce8ed0
4c80278
29cd1e9
 
7003431
905881e
4c80278
7003431
d488b3c
1197a1f
 
f94b0b4
905881e
0658391
15c5050
905881e
 
 
 
0658391
905881e
 
 
e52c641
29cd1e9
e52c641
2331f7f
0658391
 
 
 
29cd1e9
e52c641
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
# HuggingFace
from huggingface_hub import from_pretrained_keras

# Basic libraries
import os
import shutil
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle

# Astropy
from astropy.io import fits
from astropy.wcs import WCS
from astropy.nddata import Cutout2D, CCDData
from astropy.convolution import Gaussian2DKernel as Gauss
from astropy.convolution import convolve

# Scikit-learn
from sklearn.cluster import DBSCAN

# Streamlit
import streamlit as st
st.set_option('deprecation.showPyplotGlobalUse', False)
st.set_page_config(page_title="Cavity Detection Tool", layout="wide")

# Streamlit re-executes this whole script on every widget interaction, so the
# network must be cached to avoid rebuilding it on each rerun.
# ``st.cache_resource`` needs Streamlit >= 1.18; older releases call the same
# decorator ``st.experimental_singleton``.
_cache_resource = getattr(st, "cache_resource", None) or st.experimental_singleton


@_cache_resource
def load_model():
    """Load the pretrained CADET network from the HuggingFace Hub (cached per process)."""
    return from_pretrained_keras("Plsek/CADET-v1")


model = load_model()

bordersize = 0.6
_, col, _ = st.columns([bordersize, 3, bordersize])

# Output directory that is zipped for download; exist_ok keeps reruns quiet
# (the old `mkdir` shell call failed with "File exists" after the first run).
os.makedirs("predictions", exist_ok=True)

with col:
    st.markdown("# Cavity Detection Tool")

    st.markdown("Cavity Detection Tool (CADET) is a machine learning pipeline trained to detect X-ray cavities from noisy Chandra images of early-type galaxies.")
    st.markdown("To use this tool: upload your image, select the scale of interest, and make a prediction!")
    st.markdown("If you use this tool for your research, please cite [Plšek et al. 2023](https://arxiv.org/abs/2304.05457)")
    st.markdown("Input images should be centered at the centre of the galaxy and point sources should be filled with surrounding background ([dmfilth](https://cxc.cfa.harvard.edu/ciao/ahelp/dmfilth.html)).")

    # Create file uploader widget
    uploaded_file = st.file_uploader("Choose a FITS file", type=['fits'])

# Define function to plot the uploaded image
def plot_image(image, scale):
    """Show the input image with the region that will be fed to the network.

    A white square of side ``scale * 128`` pixels, centred on the image,
    outlines the cutout taken before rebinning to 128x128. The figure is
    drawn into the module-level Streamlit column ``colA``. It is then closed
    so that figures do not pile up across Streamlit reruns.

    Parameters
    ----------
    image : 2D ndarray
        Image to display (rows, columns).
    scale : int
        Cutout size in units of 128 pixels.
    """
    fig, ax = plt.subplots(figsize=(4, 4))

    size = scale * 128
    # Rectangle takes (x, y) = (column, row), so use the matching axis lengths.
    x0 = image.shape[1] // 2 - size / 2
    y0 = image.shape[0] // 2 - size / 2
    ax.imshow(image, origin="lower")
    ax.add_patch(Rectangle((x0, y0), size, size, linewidth=1, edgecolor='w', facecolor='none'))
    ax.axis('off')

    with colA:
        st.pyplot(fig)
    plt.close(fig)

# Define function to plot the prediction
def plot_prediction(pred):
    """Show the (thresholded) network prediction in the Streamlit column ``colB``.

    The figure is passed to Streamlit explicitly rather than via global pyplot
    state. It is closed afterwards so that figures do not accumulate across
    reruns.
    """
    fig, ax = plt.subplots(figsize=(4, 4))
    ax.imshow(pred, origin="lower")
    ax.axis('off')
    with colB:
        st.pyplot(fig)
    plt.close(fig)

# Define function to plot the decomposed prediction
def plot_decomposed(pred):
    """Show the decomposed-cavity label image in the Streamlit column ``colC``.

    The figure is passed to Streamlit explicitly rather than via global pyplot
    state. It is closed afterwards so that figures do not accumulate across
    reruns.
    """
    fig, ax = plt.subplots(figsize=(4, 4))
    ax.imshow(pred, origin="lower")
    ax.axis('off')
    with colC:
        st.pyplot(fig)
    plt.close(fig)
        
# Cut input image and rebin it to 128x128 pixels
def cut(data0, wcs0, scale=1):
    """Cut the central ``128 * scale`` pixel square and rebin it to 128x128.

    Parameters
    ----------
    data0 : 2D ndarray
        Input image (rows, columns); must be at least ``128 * scale`` pixels
        along both axes.
    wcs0 : astropy.wcs.WCS
        WCS of ``data0``. It is not modified; ``Cutout2D`` works on its own copy.
    scale : int
        Rebinning factor: each output pixel is the mean of a
        ``scale x scale`` block of input pixels.

    Returns
    -------
    data : ndarray of shape (128, 128)
        Rebinned cutout.
    wcs : astropy.wcs.WCS
        WCS describing the rebinned 128x128 grid.
    """
    size = 128 * scale
    # Cutout2D positions are (x, y) = (column, row).
    center = (data0.shape[1] / 2, data0.shape[0] / 2)
    cutout = Cutout2D(data0, center, (size, size), wcs=wcs0)
    data, wcs = cutout.data, cutout.wcs

    # Regrid data: average non-overlapping factor x factor blocks.
    factor = size // 128
    data = data.reshape(128, factor, 128, factor).mean(-1).mean(1)

    # Regrid wcs. CRVAL stays where it is; only the pixel grid changes.
    # A 1-based pixel coordinate p of the cutout lands at
    # (p - 0.5) / factor + 0.5 on the rebinned grid, and one new pixel spans
    # `factor` old pixels. (The previous version re-anchored CRVAL at cutout
    # pixel 63 and set CRPIX = 64 / factor, which is off by a fraction of a
    # pixel for factor > 1.)
    wcs.wcs.crpix = (wcs.wcs.crpix - 0.5) / factor + 0.5
    if wcs.wcs.has_cd():
        # With a CD matrix CDELT is ignored, so the scale must go into CD.
        wcs.wcs.cd = wcs.wcs.cd * factor
    else:
        wcs.wcs.cdelt = wcs.wcs.cdelt * factor
    wcs.array_shape = data.shape

    return data, wcs

def decompose_cavity(pred, th2=0.7, amin=10):
    """Split a thresholded prediction map into one image per cavity candidate.

    Non-zero pixels of ``pred`` are clustered with DBSCAN (``eps=1.5``,
    ``min_samples=3``). On the integer pixel grid, a radius of 1.5 covers the
    8 adjacent pixels. Each cluster becomes its own image, which holds that
    cluster's prediction values and zeros elsewhere. A cluster is kept only if
    it passes both cuts below.

    Parameters
    ----------
    pred : 2D ndarray
        Prediction map, zero outside candidate pixels.
    th2 : float
        Second threshold: at least one pixel of the cluster must exceed it.
    amin : float
        "Minimal area" cut: the *sum* of the cluster's prediction values
        must exceed it. This is a probability-weighted area, not a pixel count.

    Returns
    -------
    list of ndarray
        One float64 array of ``pred.shape`` per accepted cluster, in
        increasing DBSCAN label order. Empty if ``pred`` has no non-zero pixel.
    """
    X, Y = pred.nonzero()
    # DBSCAN rejects an empty sample set; no candidate pixels -> no cavities.
    if X.size == 0:
        return []

    # DBSCAN CLUSTERING ALGORITHM
    points = np.column_stack([X, Y])
    labels = DBSCAN(eps=1.5, min_samples=3).fit(points).labels_

    cavities = []
    # Label -1 is DBSCAN noise, not a cluster.
    for label in sorted(set(labels) - {-1}):
        members = labels == label
        xi, yi = X[members], Y[members]
        img = np.zeros(pred.shape)
        img[xi, yi] = pred[xi, yi]

        # THRESHOLDING #2
        if not (img > th2).any():
            continue

        # MINIMAL AREA
        if np.sum(img) <= amin:
            continue

        cavities.append(img)

    return cavities
    
# If file is uploaded, read in the data and plot it
# If file is uploaded, read in the data, run the network on demand and plot it.
# NOTE(review): pred.npy / thresh.npy / decomposed.npy and the predictions
# directory live in the working directory, so they are shared between
# sessions and survive a new upload. Consider moving them to st.session_state.
if uploaded_file is not None:
    with fits.open(uploaded_file) as hdul:
        data = hdul[0].data
        wcs = WCS(hdul[0].header)

    # The network consumes 128x128 inputs, so the selectable cutout sizes are
    # the multiples of 128 that fit inside the image.
    if data is None or data.ndim != 2:
        max_scale = 0
    else:
        max_scale = min(data.shape) // 128
    if max_scale < 1:
        with col:
            st.error("Input image must be a 2D image of at least 128x128 pixels.")
        st.stop()

    # Row of controls: scale selector, Detect, threshold, Decompose, Download
    _, col1, col2, col3, col4, col5, col6, _ = st.columns([bordersize, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, bordersize])
    col1.subheader("Input image")
    col3.subheader("Prediction")
    col5.subheader("Decomposed")

    with col1:
        scale = st.selectbox('Scale:', [f"{(i+1)*128}x{(i+1)*128}" for i in range(max_scale)], label_visibility="hidden")
        scale = int(scale.split("x")[0]) // 128

    with col3:
        detect = st.button('Detect')

    with col5:
        decompose = st.button('Decompose')

    # Row of plots: input, prediction, decomposition
    _, colA, colB, colC, _ = st.columns([bordersize, 1, 1, 1, bordersize])

    image = np.log10(data + 1)
    plot_image(image, scale)

    with col4:
        st.markdown("""<style>[data-baseweb="select"] {margin-top: -36px;}</style>""", unsafe_allow_html=True)
        threshold = st.slider("Threshold", 0.0, 1.0, 0.0, 0.05, label_visibility="hidden")

    if detect:
        data_cut, _ = cut(data, wcs, scale=scale)
        image_cut = np.log10(data_cut + 1)

        # Average the predictions over the four 90-degree rotations of the
        # input, rotating each prediction back before adding it.
        y_pred = 0
        for j in range(4):
            rotated = np.rot90(image_cut, j)
            pred = model.predict(rotated.reshape(1, 128, 128, 1)).reshape(128, 128)
            y_pred += np.rot90(pred, -j) / 4

        np.save("pred.npy", y_pred)

    try:
        y_pred = np.load("pred.npy")
    except (OSError, ValueError):
        # Nothing detected yet (or unreadable file): show an empty map.
        y_pred = np.zeros((128, 128))
    y_pred = np.where(y_pred > threshold, y_pred, 0)
    np.save("thresh.npy", y_pred)

    plot_prediction(y_pred)

    if decompose:
        cavs = decompose_cavity(y_pred)

        # The prediction lives on the rebinned cutout grid, so it needs the
        # cutout's WCS, not the full-image one read from the upload.
        # Assumes the scale selector has not changed since "Detect".
        _, wcs_cut = cut(data, wcs, scale=scale)

        # Start from an empty directory so that cavity files from an earlier,
        # larger decomposition are not zipped again.
        shutil.rmtree("predictions", ignore_errors=True)
        os.makedirs("predictions")

        ccd = CCDData(y_pred, unit="adu", wcs=wcs_cut)
        ccd.write(os.path.join("predictions", "predicted.fits"), overwrite=True)
        image_decomposed = np.zeros((128, 128))
        for i, cav in enumerate(cavs, start=1):
            ccd = CCDData(cav, unit="adu", wcs=wcs_cut)
            ccd.write(os.path.join("predictions", f"predicted_{i}.fits"), overwrite=True)
            # Label each cavity's pixels with its 1-based index.
            image_decomposed += i * np.where(cav > 0, 1, 0)

        shutil.make_archive("predictions", 'zip', "predictions")
        np.save("decomposed.npy", image_decomposed)

    try:
        image_decomposed = np.load("decomposed.npy")
    except (OSError, ValueError):
        image_decomposed = np.zeros((128, 128))
    plot_decomposed(image_decomposed)

    with col6:
        # The archive exists only after a decomposition; before that there is
        # nothing to download (the old code referenced an undefined `res`).
        if os.path.exists("predictions.zip"):
            with open("predictions.zip", "rb") as archive:
                st.download_button(label="Download", data=archive.read(), file_name="predicted.zip", mime="application/octet-stream")