File size: 3,905 Bytes
7454012
6414f94
 
2680aa6
7454012
 
6414f94
 
3190548
7454012
 
6414f94
7454012
6414f94
7454012
 
6414f94
 
7454012
d26c581
6414f94
588ff75
c1f2126
 
 
 
6414f94
3190548
ddce24a
c1f2126
eba1cee
3190548
2680aa6
c1f2126
3190548
58320fd
3190548
 
8120054
c1f2126
6414f94
 
58320fd
 
6414f94
7454012
6414f94
d26c581
 
 
 
 
 
7454012
d26c581
 
 
7454012
d26c581
 
 
 
 
 
 
 
44cb863
6414f94
b635f79
6414f94
 
 
 
 
9c67ee7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7454012
 
 
ce59402
8e1fb51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f3966b1
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
# Basic libraries
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle

# Astropy
from astropy.io import fits
from astropy.wcs import WCS
from astropy.nddata import Cutout2D, CCDData

# Tensorflow
from tensorflow.keras.models import load_model

# Streamlit
import streamlit as st
from streamlit.errors import StreamlitAPIException


@st.cache_resource
def _load_cadet_model(path="CADET.hdf5"):
    """Load the trained CADET Keras model from *path*.

    Streamlit re-executes this whole script on every widget interaction, so
    the load is wrapped in ``st.cache_resource``: the HDF5 file is read once
    and the same model object is reused on subsequent reruns.
    """
    return load_model(path)


model = _load_cadet_model()

# Older Streamlit versions show a deprecation banner when ``st.pyplot`` is
# called without an explicit figure; this option hides it. Newer versions
# removed the option and raise StreamlitAPIException for it, so setting it is
# best-effort only.
try:
    st.set_option('deprecation.showPyplotGlobalUse', False)
except StreamlitAPIException:
    pass


st.title("Cavity Detection Tool")

st.markdown(
    "Cavity Detection Tool (CADET) is a machine learning pipeline trained to "
    "detect X-ray cavities from noisy Chandra images of early-type galaxies. "
    "To use this tool, upload your image, select the scale of interest and "
    "make a prediction! If you use the output of this tool in your research, "
    "please cite [Plšek et al. 2023](https://arxiv.org/abs/2304.05457)."
)

# Create file uploader widget (None until the user chooses a file)
uploaded_file = st.file_uploader("Choose a FITS file", type=['fits'])

# Define function to plot the uploaded image
def plot_image(image_array, scale):
    """Render the input image in Streamlit with the analysed region outlined.

    Parameters
    ----------
    image_array : 2D array
        Image to display (the caller passes log10-scaled data).
    scale : int
        Scale factor; the white square drawn on top is ``128 * scale`` pixels
        on a side and centred on the image, i.e. roughly the region that
        ``cut`` crops for this scale.
    """
    size = scale * 128
    ny, nx = image_array.shape[0], image_array.shape[1]
    # imshow maps columns to x and rows to y, so the box corner uses the
    # column count for x and the row count for y (identical for square images).
    x0 = nx // 2 - size / 2
    y0 = ny // 2 - size / 2

    # Draw on an explicit figure instead of pyplot's global one, and close it
    # once Streamlit has rendered it so figures do not pile up across reruns.
    fig, ax = plt.subplots(figsize=(4, 4))
    ax.imshow(image_array, origin="lower")
    ax.add_patch(Rectangle((x0, y0), size, size, linewidth=1, edgecolor='w', facecolor='none'))
    ax.axis('off')
    st.pyplot(fig)
    plt.close(fig)

# Define function to plot the prediction
def plot_prediction(pred):
    """Render the CADET prediction map in Streamlit.

    ``pred`` is the 2D map produced by the detection step (the caller passes
    the thresholded, rotation-averaged model output); it is shown without axes.
    """
    fig, ax = plt.subplots(figsize=(4, 4))
    ax.imshow(pred, origin="lower")
    ax.axis('off')
    # Pass the figure explicitly and close it afterwards so pyplot's global
    # figure registry does not grow on every rerun.
    st.pyplot(fig)
    plt.close(fig)

# Cut input image and rebin it to 128x128 pixels
def cut(data0, wcs0, scale=1):
    """Crop the centre of an image and rebin it to 128x128 pixels.

    A square of ``128 * scale`` pixels centred on the image is cut out,
    block-averaged by ``scale`` along both axes, and its WCS is updated to
    describe the rebinned pixel grid.

    Parameters
    ----------
    data0 : 2D array
        Input image.
    wcs0 : astropy.wcs.WCS
        WCS of ``data0``; the returned WCS is derived from the copy held by
        ``Cutout2D``.
    scale : int
        Rebinning factor (positive integer).

    Returns
    -------
    data : ndarray, shape (128, 128)
        Rebinned cutout.
    wcs : astropy.wcs.WCS
        WCS matching the rebinned cutout.

    Raises
    ------
    ValueError
        If the image is too small to contain the requested cutout.
    """
    size = 128 * scale
    ny, nx = data0.shape[0], data0.shape[1]
    # Cutout2D takes the position as (x, y) = (column, row).
    cutout = Cutout2D(data0, (nx / 2, ny / 2), (size, size), wcs=wcs0)
    data, wcs = cutout.data, cutout.wcs
    # Cutout2D trims at the image edge, which would make the reshape below fail
    # with an opaque message; report the real problem instead.
    if data.shape != (size, size):
        raise ValueError(
            f"image of shape {data0.shape} is too small for a "
            f"{size}x{size} pixel cutout (scale={scale})"
        )

    # Regrid data: average every factor x factor block into one pixel
    factor = size // 128
    data = data.reshape(128, factor, 128, factor).mean(-1).mean(1)

    # Regrid wcs: use 0-based pixel (63, 63) of the unbinned cutout (FITS
    # pixel 64) as the reference point and express it on the binned grid.
    ra, dec = wcs.wcs_pix2world(np.array([[63, 63]]), 0)[0]
    # Binned pixel k (1-based) spans unbinned pixels (k-1)*f+1 .. k*f, so an
    # unbinned 1-based coordinate p maps to (p - 0.5) / f + 0.5.
    crpix = (64 - 0.5) / factor + 0.5
    if wcs.wcs.has_cd():
        # With a CD matrix wcslib ignores CDELT, so the matrix must be scaled.
        wcs.wcs.cd = wcs.wcs.cd * factor
    else:
        wcs.wcs.cdelt[0] = wcs.wcs.cdelt[0] * factor
        wcs.wcs.cdelt[1] = wcs.wcs.cdelt[1] * factor
    wcs.wcs.crval[0] = ra
    wcs.wcs.crval[1] = dec
    wcs.wcs.crpix[0] = crpix
    wcs.wcs.crpix[1] = crpix

    return data, wcs

    
# If a file is uploaded, read in the data, plot it and offer cavity detection
if uploaded_file is not None:
    with fits.open(uploaded_file) as hdul:
        # Iterate from the primary HDU onward so it is preferred, but fall back
        # to the first extension holding a 2D image (e.g. empty primary HDU).
        image_hdu = next(
            (hdu for hdu in hdul if hdu.data is not None and hdu.data.ndim == 2),
            None,
        )
        if image_hdu is None:
            st.error("The uploaded FITS file does not contain a 2D image.")
            st.stop()
        data = image_hdu.data
        wcs = WCS(image_hdu.header)

    # Selectable scales are those whose 128*scale cutout fits inside the image.
    max_scale = int(min(data.shape) // 128)
    if max_scale < 1:
        # Otherwise the selectbox below has no options and returns None.
        st.error("The image must be at least 128x128 pixels.")
        st.stop()

    # Make two columns
    col1, col2 = st.columns(2)
    col1.subheader("Input image")
    col2.subheader("CADET prediction")

    with col1:
        # Shift the label-less selectbox upwards (presumably to close the gap
        # left by its hidden label).
        st.markdown(
            """<style>[data-baseweb="select"] {margin-top: -50px;}</style>""",
            unsafe_allow_html=True
        )
        scale = int(st.selectbox('Scale:', list(range(1, max_scale + 1)), label_visibility="hidden"))

        # Show the full input image with the region analysed at this scale
        plot_image(np.log10(data + 1), scale)

    with col2:
        if st.button('Detect cavities'):
            data, wcs = cut(data, wcs, scale=scale)
            image = np.log10(data + 1)

            # Predict on the four 90-degree rotations in one batch, rotate each
            # prediction back to the original orientation and average them.
            batch = np.stack([np.rot90(image, j) for j in range(4)])
            preds = model.predict(batch.reshape(4, 128, 128, 1)).reshape(4, 128, 128)
            y_pred = np.mean([np.rot90(pred, -j) for j, pred in enumerate(preds)], axis=0)

            # Thresholding: zero out every pixel with a value of 0.4 or lower
            y_pred = np.where(y_pred > 0.4, y_pred, 0)

            plot_prediction(y_pred)

            # TODO: offer the prediction as a FITS download. An earlier draft
            # saved ``pred`` (a single rotation) instead of ``y_pred`` and nested
            # a second ``st.button`` inside this one, which never fires because
            # clicking it reruns the script with 'Detect cavities' unpressed.
            # Write ``CCDData(y_pred, unit="adu", wcs=wcs)`` to an in-memory
            # buffer and pass its bytes directly to ``st.download_button``.