File size: 5,799 Bytes
0ea9e86
 
 
 
7454012
6414f94
 
2680aa6
7454012
 
6414f94
 
3190548
d01a85f
 
7454012
 
 
6414f94
 
3ca195b
15c5050
125c6bf
6414f94
574acaf
a18eaf5
f27ae54
9bf2a3b
125c6bf
 
9bf2a3b
 
 
 
 
58a68f2
c1f2126
6414f94
d01a85f
ddce24a
c1f2126
d01a85f
 
2680aa6
c1f2126
3190548
0e52fa4
3190548
 
8120054
c1f2126
6414f94
 
0e52fa4
6414f94
7454012
6414f94
d26c581
 
 
 
 
 
7454012
d26c581
 
 
7454012
d26c581
 
 
 
 
 
 
 
44cb863
6414f94
15c5050
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b635f79
6414f94
 
 
 
 
9c67ee7
125c6bf
36e4579
125c6bf
 
58a68f2
0e52fa4
58a68f2
 
125c6bf
 
bb17b88
 
0e52fa4
58a68f2
 
 
8e1fb51
58a68f2
 
15c5050
bb17b88
a18eaf5
0e52fa4
58a68f2
 
 
bb17b88
 
0e52fa4
a18eaf5
 
 
e5f809c
faf2da7
 
 
 
 
 
 
 
 
 
574acaf
faf2da7
 
 
15c5050
9bf2a3b
 
 
 
fb43d84
9bf2a3b
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
# HuggingFace
from huggingface_hub import from_pretrained_keras

# Basic libraries
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle

# Astropy
from astropy.io import fits
from astropy.wcs import WCS
from astropy.nddata import Cutout2D, CCDData
from astropy.convolution import Gaussian2DKernel as Gauss
from astropy.convolution import convolve

# Streamlit
import streamlit as st
st.set_option('deprecation.showPyplotGlobalUse', False)

st.set_page_config(page_title="Cavity Detection Tool", layout="wide")


@st.cache_resource
def load_model():
    """Load the pretrained CADET network from the Hugging Face Hub.

    Streamlit re-executes this whole script on every widget interaction;
    caching the model as a resource builds it once per server process
    instead of on every rerun.
    """
    return from_pretrained_keras("Plsek/CADET-v1")


# Loaded only after set_page_config(), which must be the first Streamlit
# command (by default a cached call may display a spinner element).
model = load_model()

# st.title("Cavity Detection Tool")

bordersize = 0.6
_, col, _ = st.columns([bordersize, 3, bordersize])

with col:
    st.markdown("# Cavity Detection Tool")

    st.markdown("Cavity Detection Tool (CADET) is a machine learning pipeline trained to detect X-ray cavities from noisy Chandra images of early-type galaxies. To use this tool: upload your image, select the scale of interest, and make a prediction! If you use this tool for your research, please cite [Plšek et al. 2023](https://arxiv.org/abs/2304.05457)")

    st.markdown("Input images should be centered at the centre of the galaxy and point sources should be filled with surrounding background ([dmfilth](https://cxc.cfa.harvard.edu/ciao/ahelp/dmfilth.html)).")

    # Create file uploader widget. `uploaded_file` stays None until the user
    # picks a file; the main section of the app checks it before doing work.
    uploaded_file = st.file_uploader("Choose a FITS file", type=['fits'])

# Define function to plot the uploaded image
def plot_image(image, scale):
    """Show an image with the selected cutout region outlined in white.

    Parameters
    ----------
    image : 2-D array
        Image to display. Assumed square: ``image.shape[0]`` is used to
        centre the outline on both axes, as ``cut`` does for the cutout.
    scale : int
        Cutout size in units of 128 pixels (the outline is
        ``scale * 128`` pixels on a side).

    The figure is drawn into the module-level ``colA`` column and then
    closed, so figures do not pile up in pyplot across Streamlit reruns.
    """
    size = scale * 128
    x0 = image.shape[0] // 2 - size / 2

    fig, ax = plt.subplots(figsize=(4, 4))
    try:
        ax.imshow(image, origin="lower")
        ax.add_patch(Rectangle((x0, x0), size, size, linewidth=1, edgecolor='w', facecolor='none'))
        ax.axis('off')
        with colA:
            st.pyplot(fig)
    finally:
        plt.close(fig)

# Define function to plot the prediction
def plot_prediction(pred):
    """Show a prediction map in the module-level ``colB`` column.

    Parameters
    ----------
    pred : 2-D array
        Prediction map to display (origin at the lower-left corner).

    The figure is closed after rendering so figures do not accumulate in
    pyplot across Streamlit reruns.
    """
    fig, ax = plt.subplots(figsize=(4, 4))
    try:
        ax.imshow(pred, origin="lower")
        ax.axis('off')
        with colB:
            st.pyplot(fig)
    finally:
        plt.close(fig)

# Cut input image and rebin it to 128x128 pixels
def cut(data0, wcs0, scale=1):
    """Cut the central ``128*scale`` square out of an image and bin it to 128x128.

    Parameters
    ----------
    data0 : 2-D array
        Input image. Assumed square: ``data0.shape[0] / 2`` is used as the
        centre on both axes.
    wcs0 : astropy.wcs.WCS
        World coordinate system of ``data0``.
    scale : int
        Binning factor; the cutout is ``128*scale`` pixels on a side and must
        fit inside ``data0`` (otherwise the reshape below fails).

    Returns
    -------
    data : ndarray, shape (128, 128)
        Cutout averaged over ``scale x scale`` pixel blocks.
    wcs : astropy.wcs.WCS
        WCS of the binned cutout.
    """
    shape = data0.shape[0]
    x0 = shape / 2
    size = 128 * scale
    cutout = Cutout2D(data0, (x0, x0), (size, size), wcs=wcs0)
    data, wcs = cutout.data, cutout.wcs

    # Regrid data: average each factor x factor block of pixels.
    factor = size // 128
    data = data.reshape(128, factor, 128, factor).mean(axis=(1, 3))

    # Regrid WCS. Binning by `factor` maps 1-based pixel coordinate p of the
    # cutout to (p - 0.5) / factor + 0.5 in the binned image (pixel edges land
    # on pixel edges). Rescaling CRPIX this way keeps CRVAL, and therefore the
    # projection, unchanged while the pixel size grows by `factor`.
    wcs.wcs.crpix = (wcs.wcs.crpix - 0.5) / factor + 0.5
    wcs.wcs.cdelt = wcs.wcs.cdelt * factor

    return data, wcs

def decompose_cavity(pred, th2=0.7, amin=10):
    """Split a thresholded prediction map into individual cavity candidates.

    Nonzero pixels of ``pred`` are clustered with DBSCAN (``eps=1.5``,
    ``min_samples=3``) on their (row, column) indices. Each cluster becomes
    its own image with the same shape as ``pred``: the original prediction
    values on the cluster's pixels and zeros elsewhere.

    Parameters
    ----------
    pred : 2-D array
        Prediction map; zero-valued pixels are ignored.
    th2 : float
        A cluster is kept only if at least one of its pixels exceeds ``th2``.
    amin : float
        A cluster is kept only if the sum of its prediction values exceeds
        ``amin``.

    Returns
    -------
    list of 2-D arrays
        One image per accepted cluster; empty when ``pred`` has no nonzero
        pixels or no cluster passes both cuts.
    """
    X, Y = pred.nonzero()
    # Nothing to cluster.
    if X.size == 0:
        return []

    # DBSCAN CLUSTERING ALGORITHM on the (row, column) indices of nonzero pixels.
    # NOTE(review): DBSCAN is not imported in this file -- presumably
    # sklearn.cluster.DBSCAN; add that import before wiring this function up.
    clusters = DBSCAN(eps=1.5, min_samples=3).fit(np.column_stack((X, Y))).labels_

    cavities = []
    for label in np.unique(clusters):
        # scikit-learn's DBSCAN labels noise points -1; they form no cavity.
        if label < 0:
            continue

        member = clusters == label
        xi, yi = X[member], Y[member]
        img = np.zeros(pred.shape)
        img[xi, yi] = pred[xi, yi]

        # THRESHOLDING #2: keep only clusters with a pixel above th2.
        if not (img > th2).any():
            continue

        # MINIMAL AREA: the cut is applied to the summed prediction values.
        # NOTE(review): a pixel-count area would be np.count_nonzero(img);
        # confirm which measure is intended.
        if np.sum(img) <= amin:
            continue

        cavities.append(img)

    return cavities
    
# If file is uploaded, read in the data and plot it
if uploaded_file is not None:
    with fits.open(uploaded_file) as hdul:
        data = hdul[0].data
        wcs = WCS(hdul[0].header)

    # The pipeline below works on a single 2-D image from the primary HDU.
    if data is None or data.ndim != 2:
        st.error("The primary HDU of the uploaded file must contain a 2-D image.")
        st.stop()

    # Cutouts are integer multiples of 128 pixels that fit inside the image.
    max_scale = int(data.shape[0] // 128)
    if max_scale < 1:
        st.error("The input image must be at least 128x128 pixels.")
        st.stop()

    # Six narrow columns (three header/widget pairs) between two margins.
    _, col1, col2, col3, col4, col5, col6, _ = st.columns([bordersize, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, bordersize])
    col1.subheader("Input image")
    col3.subheader("Prediction")
    col5.subheader("Decomposed")

    with col2:
        # Options are the multiples themselves, shown as "NxN" pixel sizes.
        scale = st.selectbox(
            'Scale:',
            [i + 1 for i in range(max_scale)],
            format_func=lambda s: f"{s * 128}x{s * 128}",
            label_visibility="hidden",
        )

    with col4:
        detect = st.button('Detect')

    with col6:
        decompose = st.button('Decompose')

    # Three plot columns (input, prediction, decomposition) between margins.
    _, colA, colB, colC, _ = st.columns([bordersize, 1, 1, 1, bordersize])

    with colB:
        threshold = st.slider("Threshold", 0.0, 1.0, 0.4, 0.05, label_visibility="hidden")

    image = np.log10(data + 1)
    plot_image(image, scale)

    if detect:
        data, wcs = cut(data, wcs, scale=scale)
        image = np.log10(data + 1)

        # Average the network output over the four 90-degree rotations of the
        # input; all rotations go through the model in a single batch.
        rotations = np.stack([np.rot90(image, j) for j in range(4)])
        preds = model.predict(rotations.reshape(4, 128, 128, 1)).reshape(4, 128, 128)
        y_pred = np.mean([np.rot90(p, -j) for j, p in enumerate(preds)], axis=0)

        # Thresholding: zero every pixel at or below the slider value.
        y_pred = np.where(y_pred > threshold, y_pred, 0)

        plot_prediction(y_pred)

        # if decompose:
        #     cavs = decompose_cavity(y_pred, )

    #     ccd = CCDData(y_pred, unit="adu", wcs=wcs)
    #     ccd.write("predicted.fits", overwrite=True)
    #     with open('predicted.fits', 'rb') as f:
    #         res = f.read()

    #     with col4:
    #         pass
    #         st.markdown("""<style>[data-baseweb="select"] {margin-top: 16px;}</style>""", unsafe_allow_html=True)
    #     #     # download = st.button('Download')
    #         download = st.download_button(label="Download", data=res, file_name="predicted.fits", mime="application/octet-stream")