File size: 8,129 Bytes
0012f0e
57f9dda
 
0ea9e86
7454012
e52c641
2bb6d79
6414f94
b4f54b4
6414f94
8111624
2680aa6
7454012
 
6414f94
 
3190548
d01a85f
 
7454012
1197a1f
 
 
7454012
 
6414f94
 
 
d01a85f
ddce24a
d01a85f
 
2680aa6
3190548
32d9557
3190548
 
8120054
c1f2126
6414f94
 
32d9557
6414f94
29cd1e9
47ec853
29cd1e9
d0a6846
8111624
7cb9bcf
8111624
 
 
48b10cd
 
8111624
29cd1e9
 
 
0755e66
6414f94
d26c581
 
 
 
 
 
7454012
d26c581
 
 
7454012
d26c581
 
 
 
 
 
 
 
44cb863
6414f94
0755e66
5d6cf3b
9219e24
a8b8828
 
 
 
 
 
 
 
 
 
4161769
a8b8828
0755e66
38ebf86
15c5050
 
 
 
0755e66
7e4e8dd
15c5050
 
 
 
 
 
 
 
 
 
 
0755e66
7c32f83
15c5050
0755e66
7c32f83
15c5050
 
 
0755e66
393d78f
 
 
 
 
 
 
 
17b9e5d
6b286d7
393d78f
e9ef2e7
0755e66
 
e9ef2e7
 
 
cb61555
 
e9ef2e7
0755e66
 
e9ef2e7
 
 
 
992ff70
17c52a0
992ff70
0012f0e
e9ef2e7
 
992ff70
0012f0e
b046ded
 
 
 
 
 
b635f79
6414f94
 
 
 
 
9c67ee7
8516575
36e4579
125c6bf
 
58a68f2
c6b81d0
0e52fa4
36a5e8a
b046ded
125c6bf
17c52a0
bb17b88
8e1fb51
0755e66
17b9e5d
0755e66
 
 
c6b81d0
0755e66
17c52a0
0755e66
 
e842a8b
15c5050
bb17b88
a18eaf5
58a68f2
bb17b88
32d9557
a8b8828
e3ab055
4161769
7bcd4a9
9219e24
 
 
c6b81d0
d61ba29
393d78f
9219e24
d61ba29
 
 
091111e
d61ba29
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
# HuggingFace Hub: download the pretrained CADET model weights once at startup
from huggingface_hub import from_pretrained_keras
# Keras model loaded at module level; used by cut_n_predict below
model = from_pretrained_keras("Plsek/CADET-v1")

# Basic libraries
import os
import shutil
import numpy as np
from scipy.ndimage import center_of_mass
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
from matplotlib.patches import Rectangle

# Astropy: FITS I/O, WCS handling, cutouts, and smoothing kernels
from astropy.io import fits
from astropy.wcs import WCS
from astropy.nddata import Cutout2D, CCDData
from astropy.convolution import Gaussian2DKernel as Gauss
from astropy.convolution import convolve

# Scikit-learn: DBSCAN is used to split the prediction into individual cavities
from sklearn.cluster import DBSCAN

# Streamlit
import streamlit as st
# Silence the warning about calling st.pyplot() with no explicit figure
st.set_option('deprecation.showPyplotGlobalUse', False)

# Define function to plot the uploaded image
def plot_image(image, scale):
    """Render the input image in column colA with a square marking the analysed region.

    The square has side scale*128 pixels and is centred on the image centre.
    """
    plt.figure(figsize=(4, 4))
    # Lower-left corner of the centred scale*128 square
    corner = image.shape[0] // 2 - scale * 128 / 2
    plt.imshow(image, origin="lower")
    box = Rectangle((corner, corner), scale * 128, scale * 128,
                    linewidth=1, edgecolor='w', facecolor='none')
    plt.gca().add_patch(box)
    plt.axis('off')
    with colA:
        st.pyplot()

# Define function to plot the prediction
def plot_prediction(pred):
    """Render the raw CADET prediction map in column colB."""
    plt.figure(figsize=(4, 4))
    plt.imshow(pred, origin="lower")
    plt.axis('off')
    with colB:
        st.pyplot()

# Define function to plot the decomposed prediction
def plot_decomposed(decomposed):
    """Render the decomposed cavity map in column colC, labelling each cavity.

    `decomposed` is a 2-D array where pixels of cavity k hold the value k
    (0 = background); each cavity is annotated with its number at its
    centre of mass.
    """
    plt.figure(figsize=(4, 4))
    plt.imshow(decomposed, origin="lower")

    n_cavities = int(np.max(decomposed))
    for label in range(1, n_cavities + 1):
        mask = np.where(decomposed == label, 1, 0)
        row, col = center_of_mass(mask)
        # White labels for the first half of the cavities, black for the rest
        text_color = "white" if label - 1 < n_cavities // 2 else "black"
        plt.text(col, row, f"{label}", ha="center", va="center", fontsize=15, color=text_color)

    plt.axis('off')
    with colC:
        st.pyplot()
        
# Define function to cut input image and rebin it to 128x128 pixels
def cut(data0, wcs0, scale=1):
    """Cut a centred square of side 128*scale from data0 and rebin it to 128x128.

    Returns the rebinned data together with a WCS adjusted to the new grid.
    Assumes data0 is square with side a multiple of 128*scale — TODO confirm
    against callers.
    """
    centre = data0.shape[0] / 2
    side = 128 * scale
    cutout = Cutout2D(data0, (centre, centre), (side, side), wcs=wcs0)
    data, wcs = cutout.data, cutout.wcs

    # Block-average the cutout down to 128x128
    factor = side // 128
    data = data.reshape(128, factor, 128, factor).mean(-1).mean(1)

    # Adjust the WCS: sky position of the (pre-rebin) reference pixel,
    # then scale the pixel size and move the reference pixel accordingly
    ra, dec = wcs.wcs_pix2world(np.array([[63, 63]]), 0)[0]
    wcs.wcs.cdelt[0] = wcs.wcs.cdelt[0] * factor
    wcs.wcs.cdelt[1] = wcs.wcs.cdelt[1] * factor
    wcs.wcs.crval[0] = ra
    wcs.wcs.crval[1] = dec
    wcs.wcs.crpix[0] = 64 / factor
    wcs.wcs.crpix[1] = 64 / factor

    return data, wcs

# Define function to apply cutting and produce a prediction
@st.cache
def cut_n_predict(data, wcs, scale):
    """Cut/rebin the image and return the CADET prediction averaged over
    the four 90-degree rotations (simple test-time augmentation), plus the
    adjusted WCS."""
    data, wcs = cut(data, wcs, scale=scale)
    image = np.log10(data + 1)

    y_pred = 0
    for rot in range(4):
        view = np.rot90(image, rot)
        pred = model.predict(view.reshape(1, 128, 128, 1)).reshape(128, 128)
        # Rotate the prediction back before averaging
        y_pred += np.rot90(pred, -rot) / 4

    return y_pred, wcs

# Define function to decompose prediction into individual cavities
@st.cache
def decompose_cavity(pred, th2=0.7, amin=10):
    """Split a thresholded prediction map into individual cavities.

    Clusters the non-zero pixels of `pred` with DBSCAN, discards clusters
    whose peak value never exceeds `th2` or whose integrated signal is at
    most `amin`, writes each surviving cavity (and the raw prediction) to
    the predictions/ folder, and returns a 128x128 map where the pixels of
    cavity k hold the value k.

    NOTE(review): relies on the module-level `wcs` (set when a file is
    loaded), not a parameter — confirm callers always load a file first.
    """
    X, Y = pred.nonzero()
    data = np.array([X, Y]).reshape(2, -1)

    # DBSCAN clustering; treat any fitting failure (e.g. no non-zero pixels)
    # as "no clusters" rather than crashing the app
    try:
        clusters = DBSCAN(eps=1.0, min_samples=3).fit(data.T).labels_
    except Exception:
        clusters = []

    # len(set(...)) also counts the noise label (-1) if present; the
    # resulting empty pseudo-cluster is filtered out by the checks below
    N = len(set(clusters))
    cavities = []

    for i in range(N):
        img = np.zeros((128, 128))
        b = clusters == i
        xi, yi = X[b], Y[b]
        img[xi, yi] = pred[xi, yi]

        # Thresholding #2: keep only cavities whose peak value exceeds th2
        if not (img > th2).any():
            continue

        # Minimal area: discard cavities with too little integrated signal
        if np.sum(img) <= amin:
            continue

        cavities.append(img)

    # Save raw and decomposed predictions to the predictions folder
    ccd = CCDData(pred, unit="adu", wcs=wcs)
    ccd.write("predictions/predicted.fits", overwrite=True)
    image_decomposed = np.zeros((128, 128))
    for i, cav in enumerate(cavities):
        ccd = CCDData(cav, unit="adu", wcs=wcs)
        ccd.write(f"predictions/predicted_{i+1}.fits", overwrite=True)
        image_decomposed += (i + 1) * np.where(cav > 0, 1, 0)

    return image_decomposed

# Use wide layout and create the centred content column
st.set_page_config(page_title="Cavity Detection Tool", layout="wide")
bordersize = 0.6
_, col, _ = st.columns([bordersize, 3, bordersize])

# Ensure the output folder exists — portable and shell-free,
# unlike the previous os.system("mkdir -p predictions")
os.makedirs("predictions", exist_ok=True)

# Heading and usage description, rendered in order inside the centre column
intro_markdown = (
    "# Cavity Detection Tool",
    "Cavity Detection Tool (CADET) is a machine learning pipeline trained to detect X-ray cavities from noisy Chandra images of early-type galaxies.",
    "To use this tool: upload your image, select the scale of interest, and make a prediction!",
    "Input images should be centred at the centre of the galaxy and point sources should be filled with surrounding background ([dmfilth](https://cxc.cfa.harvard.edu/ciao/ahelp/dmfilth.html)).",
    "If you use this tool for your research, please cite [Plšek et al. 2023](https://arxiv.org/abs/2304.05457)",
)
with col:
    for paragraph in intro_markdown:
        st.markdown(paragraph)

# Two content columns: uploader on the left, example buttons on the right
_, col_1, _, col_2, _ = st.columns([bordersize, 2.0, 0.5, 0.5, bordersize])

with col_1:
    # FITS file uploader widget
    uploaded_file = st.file_uploader("Choose a FITS file", type=['fits'])

with col_2:
    st.markdown("# Examples")
    NGC4649 = st.button("NGC4649")
    NGC5813 = st.button("NGC5813")

# An example button overrides the uploader with a bundled file path (a str)
if NGC4649:
    uploaded_file = "NGC4649_example.fits"
elif NGC5813:
    uploaded_file = "NGC5813_example.fits"
    
# If file is uploaded, read in the data and plot it
# If a file was uploaded (or an example selected), read the data and run the UI
if uploaded_file is not None:
    # Read the image and its WCS from the primary HDU
    with fits.open(uploaded_file) as hdul:
        data = hdul[0].data
        wcs = WCS(hdul[0].header)

    # Six columns for the control widgets
    _, col1, col2, col3, col4, col5, col6, _ = st.columns([bordersize, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, bordersize])
    col1.subheader("Input image")
    col3.subheader("Prediction")
    col5.subheader("Decomposed")
    col6.subheader("")

    with col1:
        # Pull the selectbox up so it lines up with the buttons
        st.markdown("""<style>[data-baseweb="select"] {margin-top: -46px;}</style>""", unsafe_allow_html=True)
        # Offer every square scale (multiples of 128 px) that fits in the image
        max_scale = int(data.shape[0] // 128)
        scale = st.selectbox('Scale:', [f"{(i+1)*128}x{(i+1)*128}" for i in range(max_scale)], label_visibility="hidden")
        # "256x256" -> 2, etc.
        scale = int(scale.split("x")[0]) // 128

    # Detect button
    with col3:
        detect = st.button('Detect', key="detect")

    # Threshold slider (0 by default; also acts as a re-run trigger below)
    with col4:
        st.markdown("")
        threshold = st.slider("Threshold", 0.0, 1.0, 0.0, 0.05)

    # Decompose button
    with col5:
        decompose = st.button('Decompose', key="decompose")

    # Three columns for the plots
    _, colA, colB, colC, _ = st.columns([bordersize, 1, 1, 1, bordersize])

    image = np.log10(data + 1)
    plot_image(image, scale)

    # Predict when Detect is pressed or a non-zero threshold is chosen
    if detect or threshold:
        y_pred, wcs = cut_n_predict(data, wcs, scale)

        # Zero out pixels at or below the chosen threshold
        y_pred_th = np.where(y_pred > threshold, y_pred, 0)

        plot_prediction(y_pred_th)

        if decompose or st.session_state.get("download", False):
            image_decomposed = decompose_cavity(y_pred_th)

            plot_decomposed(image_decomposed)

            with col6:
                st.markdown("<br style='margin:4px 0'>", unsafe_allow_html=True)
                # The example buttons supply a plain path string, while the
                # uploader yields an UploadedFile with a .name attribute —
                # support both (previously `uploaded_file.name` crashed on the
                # example paths).
                name = uploaded_file if isinstance(uploaded_file, str) else uploaded_file.name
                # BUGFIX: str.strip(".fits") strips any of '.', 'f', 'i', 't',
                # 's' from BOTH ends (mangling names like "fits_image.fits");
                # remove only the extension instead.
                fname = name[:-len(".fits")] if name.endswith(".fits") else name

                # Bundle all FITS predictions into a zip for download
                shutil.make_archive("predictions", 'zip', "predictions")
                with open('predictions.zip', 'rb') as f:
                    res = f.read()

                download = st.download_button(label="Download", data=res, key="download",
                                              file_name=f'{fname}_{int(scale*128)}.zip',
                                              mime="application/octet-stream")