File size: 3,415 Bytes
6414f94
 
 
2680aa6
6414f94
 
3190548
6414f94
 
 
 
d26c581
6414f94
c1f2126
 
 
6414f94
 
c1f2126
 
 
 
 
6414f94
3190548
ddce24a
c1f2126
eba1cee
3190548
2680aa6
c1f2126
3190548
c1f2126
 
3190548
 
 
c1f2126
 
 
 
6414f94
c1f2126
6414f94
 
c1f2126
6414f94
 
d26c581
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44cb863
6414f94
 
 
 
 
 
 
 
2680aa6
 
 
3190548
 
eba1cee
3190548
 
 
 
eba1cee
 
2e34964
eba1cee
 
 
6414f94
2d330f9
 
d26c581
eba1cee
8e671ec
2d330f9
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import streamlit as st
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from astropy.io import fits
from astropy.wcs import WCS
from astropy.nddata import Cutout2D, CCDData
from tensorflow.keras.models import load_model

# NOTE(review): this config option may not be accepted by newer Streamlit
# releases -- confirm against the pinned Streamlit version.
st.set_option('deprecation.showPyplotGlobalUse', False)

st.title("Cavity Detection Tool")

# st.markdown (rather than st.text, which shows text verbatim) so that the
# citation renders as a clickable link.
st.markdown(
    "Cavity Detection Tool (CADET) is a machine learning pipeline trained to detect X-ray cavities "
    "from noisy Chandra images of early-type galaxies. "
    "To use this tool upload your image, select the scale of interest and make a prediction! "
    "If you use output of this tool in your research please cite "
    "[Plšek et al. 2023](https://arxiv.org/abs/2304.05457)"
)


def _load_cadet_model():
    """Load the trained CADET network from ``CADET.hdf5``."""
    return load_model("CADET.hdf5")


# Streamlit re-executes this whole script on every widget interaction; cache
# the model (when this Streamlit version provides st.cache_resource) so it is
# not re-read from disk each time the slider or button is used.
_cache_resource = getattr(st, "cache_resource", None)
if _cache_resource is not None:
    _load_cadet_model = _cache_resource(_load_cadet_model)
model = _load_cadet_model()

# Create file uploader widget
uploaded_file = st.file_uploader("Choose a FITS file", type=['fits'])

col1, col2 = st.columns(2)

# Define function to plot the uploaded image
def plot_image(image_array, scale):
    plt.figure(figsize=(4, 4))

    x0 = image_array.shape[0] // 2 - scale * 128 / 2
    plt.imshow(image_array, origin="lower")
    plt.gca().add_patch(Rectangle((x0, x0), scale*128, scale*128, linewidth=1, edgecolor='w', facecolor='none'))
    
    plt.axis('off')
    fig.set_size_inches((4,4))
    st.pyplot()

# Define function to plot the prediction
def plot_prediction(image_array, pred):
    """Display the prediction map *pred* in the second page column (``col2``).

    Parameters
    ----------
    image_array : 2-D array
        Input image; currently unused, kept so existing callers still work.
    pred : 2-D array
        Prediction map to display.
    """
    fig, ax = plt.subplots(figsize=(4, 4))
    ax.imshow(pred, origin="lower")
    ax.axis('off')

    # Pass the figure explicitly instead of relying on pyplot's global state.
    col2.pyplot(fig)
    # Close the figure so open figures do not pile up across script reruns.
    plt.close(fig)

def cut(data0, wcs0, scale=1):
    """Cut out the central ``128 * scale`` pixel square and rebin it to 128x128.

    Parameters
    ----------
    data0 : 2-D array
        Full input image.
    wcs0 : astropy.wcs.WCS or None
        World coordinate system of ``data0``.
    scale : int, optional
        Integer binning factor. The cutout is ``128 * scale`` pixels on a
        side and each output pixel is the mean of a ``scale x scale`` block.

    Returns
    -------
    data : numpy.ndarray
        The 128x128 block-averaged cutout.
    wcs : astropy.wcs.WCS or None
        WCS describing the rebinned cutout (None if ``wcs0`` is None).

    Raises
    ------
    ValueError
        If the image is too small to contain the requested cutout.
    """
    size = 128 * scale
    # Cutout2D takes the centre as (x, y), i.e. (column, row).
    centre = (data0.shape[1] / 2, data0.shape[0] / 2)
    cutout = Cutout2D(data0, centre, (size, size), wcs=wcs0)
    data, wcs = cutout.data, cutout.wcs

    # In its default "trim" mode Cutout2D returns a smaller array when the box
    # runs past the image edge; such an array cannot be rebinned below.
    if data.shape != (size, size):
        raise ValueError(
            f"image of shape {data0.shape} is too small for a "
            f"{size}x{size} pixel cutout (scale={scale})"
        )

    # REGRID DATA: split each axis into (128, factor) and average the
    # factor-sized sub-axes.
    factor = size // 128
    data = data.reshape(128, factor, 128, factor).mean(axis=(1, 3))

    # REGRID WCS: old 1-based pixel p lands at new pixel (p - 0.5) / factor + 0.5,
    # so move CRPIX the same way and keep CRVAL (the projection reference point)
    # unchanged; the pixel scale grows by `factor`.
    # NOTE: SIP distortion terms, if present, are not rescaled here.
    if wcs is not None:
        wcs.wcs.crpix = (wcs.wcs.crpix - 0.5) / factor + 0.5
        if wcs.wcs.has_cd():
            # With a CD matrix, CDELT is ignored, so the scale lives in CD.
            wcs.wcs.cd = wcs.wcs.cd * factor
        else:
            wcs.wcs.cdelt = wcs.wcs.cdelt * factor
        wcs.array_shape = data.shape

    return data, wcs


# If file is uploaded, read in the data and plot it
if uploaded_file is not None:
    with fits.open(uploaded_file) as hdul:
        # Use the first HDU that holds a 2-D image: some FITS files keep an
        # empty primary HDU and store the image in an extension.
        image_hdu = next(
            (hdu for hdu in hdul if hdu.data is not None and np.ndim(hdu.data) == 2),
            None,
        )

        if image_hdu is None:
            st.error("The uploaded FITS file does not contain a 2-D image.")
        else:
            # Work in float so `data + 1` cannot wrap around for unsigned
            # integer images.
            data = np.asarray(image_hdu.data, dtype=float)
            wcs = WCS(image_hdu.header)

            # Add a slider to change the scale
            scale = st.slider("Scale", 1, 4, 1, 1)

            plot_image(np.log10(data + 1), scale)

            if st.button('Detect cavities'):
                data, wcs = cut(data, wcs, scale=scale)
                image_data = np.log10(data + 1)

                # Test-time augmentation: predict on the four 90-degree
                # rotations in one batch, rotate each prediction back to the
                # original orientation and average them.
                rotations = np.stack([np.rot90(image_data, k) for k in range(4)])
                preds = model.predict(rotations[..., np.newaxis]).reshape(4, 128, 128)
                y_pred = np.mean([np.rot90(p, -k) for k, p in enumerate(preds)], axis=0)

                plot_prediction(image_data, y_pred)

                # TODO: offer y_pred as a FITS download, e.g. by writing
                # CCDData(y_pred, unit="adu", wcs=wcs) and passing the bytes to
                # st.download_button. (An earlier draft saved `pred` -- the last
                # rotation only -- instead of the averaged `y_pred`.)