File size: 4,428 Bytes
7454012
6414f94
 
2680aa6
7454012
 
6414f94
 
3190548
d01a85f
 
7454012
248cf5e
31764f2
 
6414f94
7454012
 
6414f94
 
7454012
d26c581
6414f94
ff9f313
c1f2126
 
 
 
6414f94
d01a85f
ddce24a
c1f2126
d01a85f
 
2680aa6
c1f2126
3190548
0e52fa4
3190548
d01a85f
 
 
 
 
 
 
3190548
8120054
c1f2126
6414f94
 
0e52fa4
6414f94
7454012
6414f94
d26c581
 
 
 
 
 
7454012
d26c581
 
 
7454012
d26c581
 
 
 
 
 
 
 
44cb863
6414f94
b635f79
6414f94
 
 
 
 
9c67ee7
0ded826
0e52fa4
9c67ee7
778987b
e5f809c
7454012
 
0e52fa4
 
 
0ded826
0a4f4ea
 
0e52fa4
 
05f58f4
8e1fb51
0ded826
0e52fa4
 
d01a85f
 
 
0e52fa4
0a4f4ea
 
 
e5f809c
0a4f4ea
 
 
 
 
 
 
 
 
 
d01a85f
0a4f4ea
16cf1a4
 
 
0ded826
fb43d84
d01a85f
0ded826
 
5bbb151
0ded826
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
# Basic libraries
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle

# Astropy
from astropy.io import fits
from astropy.wcs import WCS
from astropy.nddata import Cutout2D, CCDData
from astropy.convolution import Gaussian2DKernel as Gauss
from astropy.convolution import convolve

# Streamlit
import streamlit as st
st.set_option('deprecation.showPyplotGlobalUse', False)

# HuggingFace
from huggingface_hub import from_pretrained_keras


def _no_cache(func):
    """Fallback decorator used when Streamlit offers no resource cache."""
    return func


# Streamlit re-executes this script from the top on every widget interaction,
# so an uncached load would rebuild the model on each rerun. Prefer
# st.cache_resource and fall back for Streamlit releases that do not have it.
_cache_model = (
    getattr(st, "cache_resource", None)
    or getattr(st, "experimental_singleton", None)
    or _no_cache
)


@_cache_model
def _load_model():
    """Return the pretrained CADET Keras model from the HuggingFace Hub."""
    return from_pretrained_keras("Plsek/CADET-v1")


model = _load_model()


# Page heading
st.title("Cavity Detection Tool")

# Introductory text: what the tool does, how to use it, and the citation request
st.markdown(
    "Cavity Detection Tool (CADET) is a machine learning pipeline trained to detect X-ray cavities from noisy Chandra images of early-type galaxies. To use this tool: upload your image, select the scale of interest, and make a prediction! If you use the output of this tool in your research please cite [Plšek et al. 2023](https://arxiv.org/abs/2304.05457)"
)

# Upload widget restricted to .fits files
uploaded_file = st.file_uploader("Choose a FITS file", type=["fits"])

# Define function to plot the uploaded image
def plot_image(image, scale):
    """Show the input image in ``colA`` with the selected cutout outlined.

    A white square with a side of ``scale * 128`` pixels is drawn around the
    image centre.

    Args:
        image: 2-D array to display.
        scale: Size of the outlined square in multiples of 128 pixels.
    """
    side = scale * 128
    # The corner is derived from the first image dimension for both axes,
    # i.e. the image is assumed to be square.
    x0 = image.shape[0] // 2 - side / 2

    fig, ax = plt.subplots(figsize=(4, 4))
    ax.imshow(image, origin="lower")
    ax.add_patch(Rectangle((x0, x0), side, side, linewidth=1,
                           edgecolor='w', facecolor='none'))
    ax.axis('off')

    # Pass the figure explicitly instead of relying on the global pyplot
    # state, then close it so figures do not accumulate across reruns.
    with colA:
        st.pyplot(fig)
    plt.close(fig)

# Define function to smooth image
def smooth_image(image, scale):
    """Return ``image`` convolved with a Gaussian kernel (stddev 2 in x and y).

    The convolution wraps around the image boundary and interpolates over
    NaN values. ``scale`` is accepted but not used.
    """
    gaussian = Gauss(x_stddev=2, y_stddev=2)
    return convolve(
        image,
        kernel=gaussian,
        boundary="wrap",
        nan_treatment="interpolate",
    )

# Define function to plot the prediction
def plot_prediction(pred):
    """Show the prediction map in ``colB``.

    Args:
        pred: 2-D array to display.
    """
    fig, ax = plt.subplots(figsize=(4, 4))
    ax.imshow(pred, origin="lower")
    ax.axis('off')

    # Pass the figure explicitly instead of relying on the global pyplot
    # state, then close it so figures do not accumulate across reruns.
    with colB:
        st.pyplot(fig)
    plt.close(fig)

# Cut input image and rebin it to 128x128 pixels
def cut(data0, wcs0, scale=1):
    """Crop the central ``128 * scale`` square and rebin it to 128x128 pixels.

    Args:
        data0: 2-D image array. The cutout centre is derived from
            ``data0.shape[0]`` for both axes, so a square image is assumed.
        wcs0: WCS belonging to ``data0``; forwarded to ``Cutout2D``.
        scale: Integer binning factor; the cutout side is ``128 * scale``.

    Returns:
        Tuple ``(data, wcs)``: the block-averaged 128x128 image and the
        cutout WCS adjusted to the rebinned pixel grid.

    Raises:
        ValueError: If the cutout is not a full ``128 * scale`` square, e.g.
            because the image is smaller than the requested cutout.
    """
    center = data0.shape[0] / 2
    size = 128 * scale
    cutout = Cutout2D(data0, (center, center), (size, size), wcs=wcs0)
    data, wcs = cutout.data, cutout.wcs

    factor = size // 128
    if data.shape != (size, size):
        # A truncated cutout cannot be split into factor x factor blocks.
        raise ValueError(
            f"cutout has shape {data.shape}, expected ({size}, {size}); "
            "the image is too small for the requested scale"
        )

    # Regrid data: average every factor x factor block into one pixel
    data = data.reshape(128, factor, 128, factor).mean(-1).mean(1)

    # Regrid wcs: re-anchor the reference point at cutout pixel (63, 63)
    # (0-based), evaluated on the cutout grid before rebinning.
    ra, dec = wcs.wcs_pix2world(np.array([[63, 63]]), 0)[0]

    # Each rebinned pixel is `factor` times larger. When the header carries a
    # CD matrix, CDELT is ignored, so the matrix itself has to be scaled.
    if wcs.wcs.has_cd():
        wcs.wcs.cd = wcs.wcs.cd * factor
    else:
        wcs.wcs.cdelt[0] = wcs.wcs.cdelt[0] * factor
        wcs.wcs.cdelt[1] = wcs.wcs.cdelt[1] * factor
    wcs.wcs.crval[0] = ra
    wcs.wcs.crval[1] = dec

    # FITS pixel centres lie on integer coordinates, so 1-based pixel 64 of
    # the cutout sits at (64 - 0.5) / factor + 0.5 on the rebinned grid.
    ref_pix = (64 - 0.5) / factor + 0.5
    wcs.wcs.crpix[0] = ref_pix
    wcs.wcs.crpix[1] = ref_pix

    return data, wcs

    
# If file is uploaded, read in the data and plot it
if uploaded_file is not None:
    # Read the image array and build the WCS from the primary HDU
    with fits.open(uploaded_file) as hdul:
        data = hdul[0].data
        wcs = WCS(hdul[0].header)

    # The cutout is at least 128 pixels wide and is taken from a 2-D array;
    # stop with a message instead of crashing further down on other inputs.
    if data is None or data.ndim != 2 or data.shape[0] < 128:
        st.error("The primary HDU must contain a 2D image of at least 128x128 pixels.")
        st.stop()

    # Largest selectable cutout size, in multiples of 128 pixels
    max_scale = int(data.shape[0] // 128)

    # Make four columns for buttons
    col1, col2, col3, col4 = st.columns(4)
    col1.subheader("Input image")
    col3.subheader("Prediction")

    # Button that toggles smoothing of the displayed image
    with col1:
        smooth = st.button("Smooth")

    # Selector for the cutout scale
    with col2:
        st.markdown("""<style>[data-baseweb="select"] {margin-top: 17px;}</style>""", unsafe_allow_html=True)
        scale = int(st.selectbox('Scale:', list(range(1, max_scale + 1)), label_visibility="hidden"))

    with col3:
        detect = st.button('Detect cavities')

    # TODO: col4 is currently empty; a download button for predicted.fits
    # could be placed here.

    # Make two columns for plots
    colA, colB = st.columns(2)

    image = np.log10(data+1)
    if smooth:
        image = smooth_image(image, scale)
    plot_image(image, scale)

    if detect:
        data, wcs = cut(data, wcs, scale=scale)
        image = np.log10(data+1)

        # Average the model output over the four 90-degree rotations of the
        # input, rotating each prediction back before it is accumulated.
        y_pred = 0
        for j in [0, 1, 2, 3]:
            rotated = np.rot90(image, j)
            pred = model.predict(rotated.reshape(1, 128, 128, 1)).reshape(128, 128)
            pred = np.rot90(pred, -j)
            y_pred += pred / 4

        # Thresholding: values not above 0.4 are set to zero
        y_pred = np.where(y_pred > 0.4, y_pred, 0)

        plot_prediction(y_pred)

        # Save the prediction together with the rebinned WCS
        ccd = CCDData(y_pred, unit="adu", wcs=wcs)
        ccd.write("predicted.fits", overwrite=True)