File size: 3,905 Bytes
7454012 6414f94 2680aa6 7454012 6414f94 3190548 7454012 6414f94 7454012 6414f94 7454012 6414f94 7454012 d26c581 6414f94 588ff75 c1f2126 6414f94 3190548 ddce24a c1f2126 eba1cee 3190548 2680aa6 c1f2126 3190548 58320fd 3190548 8120054 c1f2126 6414f94 58320fd 6414f94 7454012 6414f94 d26c581 7454012 d26c581 7454012 d26c581 44cb863 6414f94 b635f79 6414f94 9c67ee7 7454012 ce59402 8e1fb51 f3966b1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 |
# Basic libraries
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
# Astropy
from astropy.io import fits
from astropy.wcs import WCS
from astropy.nddata import Cutout2D, CCDData
# Tensorflow
from tensorflow.keras.models import load_model
# Streamlit
import streamlit as st

# Streamlit re-executes this whole script on every widget interaction, so the
# network is cached instead of being re-read from CADET.hdf5 on every click.
# ``st.cache_resource`` exists in Streamlit >= 1.18; older releases only offer
# the equivalent ``st.experimental_singleton``.
_cache_resource = getattr(st, "cache_resource", None) or st.experimental_singleton


@_cache_resource
def _load_cadet_model(path="CADET.hdf5"):
    """Load the trained CADET Keras model from *path*, once per server process."""
    return load_model(path)


model = _load_cadet_model()

# Silences the warning Streamlit shows when st.pyplot() is called without an
# explicit figure. NOTE(review): newer Streamlit releases may no longer accept
# this option - verify against the pinned version before upgrading.
st.set_option('deprecation.showPyplotGlobalUse', False)
st.title("Cavity Detection Tool")
st.markdown(
    "Cavity Detection Tool (CADET) is a machine learning pipeline trained to "
    "detect X-ray cavities from noisy Chandra images of early-type galaxies. "
    "To use this tool, upload your image, select the scale of interest and "
    "make a prediction! If you use the output of this tool in your research, "
    "please cite [Plšek et al. 2023](https://arxiv.org/abs/2304.05457)"
)
# Create file uploader widget
uploaded_file = st.file_uploader("Choose a FITS file", type=['fits'])
def plot_image(image_array, scale):
    """Show the input image in Streamlit with the analysed region outlined.

    Parameters
    ----------
    image_array : 2-D numpy array
        Image to display (the app passes the log-scaled data, log10(data + 1)).
    scale : int
        Cutout scale; a white ``128 * scale`` pixel square centred on the
        image is drawn on top.
    """
    fig, ax = plt.subplots(figsize=(4, 4))
    box = scale * 128
    # Lower-left corner of the centred square. Only shape[0] is used for both
    # axes, i.e. the image is assumed to be square.
    x0 = image_array.shape[0] // 2 - box / 2
    ax.imshow(image_array, origin="lower")
    ax.add_patch(Rectangle((x0, x0), box, box, linewidth=1, edgecolor='w', facecolor='none'))
    ax.axis('off')
    # Pass the figure explicitly rather than relying on pyplot's implicit
    # "current figure", then close it: the script reruns on every interaction
    # and unclosed figures would pile up in pyplot's global registry.
    st.pyplot(fig)
    plt.close(fig)
def plot_prediction(pred):
    """Show a 2-D prediction map in Streamlit (the app passes the thresholded CADET output).

    Parameters
    ----------
    pred : 2-D numpy array
        Map to display with ``origin="lower"``, matching the input image.
    """
    fig, ax = plt.subplots(figsize=(4, 4))
    ax.imshow(pred, origin="lower")
    ax.axis('off')
    # Explicit figure + close, so figures do not accumulate across reruns.
    st.pyplot(fig)
    plt.close(fig)
def cut(data0, wcs0, scale=1):
    """Cut out the central region of an image and rebin it to 128x128 pixels.

    Parameters
    ----------
    data0 : 2-D numpy array
        Input image. Only ``data0.shape[0]`` is used to locate the centre, so
        the image is assumed to be square.
    wcs0 : astropy.wcs.WCS
        WCS of ``data0``; it is copied by Cutout2D, not modified.
    scale : int, optional
        The cutout is ``128 * scale`` pixels on a side and is rebinned by
        averaging each ``scale x scale`` block of pixels.

    Returns
    -------
    data : numpy.ndarray, shape (128, 128)
        Rebinned cutout.
    wcs : astropy.wcs.WCS
        WCS describing the rebinned cutout.
    """
    x0 = data0.shape[0] / 2
    size = 128 * scale
    cutout = Cutout2D(data0, (x0, x0), (size, size), wcs=wcs0)
    data, wcs = cutout.data, cutout.wcs

    # Regrid data: mean of every factor x factor block.
    factor = size // 128
    data = data.reshape(128, factor, 128, factor).mean(-1).mean(1)

    # Regrid WCS, anchored at the exact centre of the cutout. In 0-based pixel
    # coordinates that is (size - 1) / 2 on the original grid and 63.5 on the
    # binned grid, i.e. CRPIX = 64.5 in 1-based FITS convention. (Pairing old
    # pixel 63 with CRPIX 64 / factor, as before, is off by up to half a
    # binned pixel whenever factor > 1.)
    centre = (size - 1) / 2
    world_x, world_y = wcs.wcs_pix2world(np.array([[centre, centre]]), 0)[0]
    if wcs.wcs.has_cd():
        # With a CD matrix the pixel scale lives in CD and CDELT is ignored.
        wcs.wcs.cd = wcs.wcs.cd * factor
    else:
        wcs.wcs.cdelt[0] = wcs.wcs.cdelt[0] * factor
        wcs.wcs.cdelt[1] = wcs.wcs.cdelt[1] * factor
    wcs.wcs.crval[0] = world_x
    wcs.wcs.crval[1] = world_y
    wcs.wcs.crpix[0] = 64.5
    wcs.wcs.crpix[1] = 64.5
    return data, wcs
# If a file is uploaded, read it, display it and run CADET on request.
if uploaded_file is not None:
    with fits.open(uploaded_file) as hdul:
        # Use the first HDU that holds image data: some FITS files keep only a
        # header in the primary HDU and store the image in an extension.
        hdu = next((h for h in hdul if h.is_image and h.data is not None), None)
        if hdu is None:
            st.error("The uploaded FITS file contains no image data.")
            st.stop()
        data = hdu.data
        wcs = WCS(hdu.header)

    # Largest selectable scale (the cutout is 128 * scale pixels wide). Only
    # shape[0] is checked, so the image is assumed to be square.
    max_scale = int(data.shape[0] // 128)
    if max_scale < 1:
        # Otherwise the scale selector would be empty and int(None) would raise.
        st.error("The image must be at least 128 pixels wide.")
        st.stop()

    # Make two columns
    col1, col2 = st.columns(2)
    col1.subheader("Input image")
    col2.subheader("CADET prediction")

    with col1:
        # Shift the scale selector up by 50px: its label is hidden but still
        # reserves space.
        st.markdown(
            """<style>[data-baseweb="select"] {margin-top: -50px;}</style>""",
            unsafe_allow_html=True
        )
        scale = int(st.selectbox('Scale:', [i + 1 for i in range(max_scale)], label_visibility="hidden"))
        plot_image(np.log10(data + 1), scale)

    with col2:
        if st.button('Detect cavities'):
            data, wcs = cut(data, wcs, scale=scale)
            image = np.log10(data + 1)
            # Test-time augmentation: predict on the four 90-degree rotations
            # in a single batched call, rotate each prediction back and
            # average them.
            batch = np.stack([np.rot90(image, k) for k in range(4)])
            preds = model.predict(batch.reshape(4, 128, 128, 1)).reshape(4, 128, 128)
            y_pred = np.mean([np.rot90(p, -k) for k, p in enumerate(preds)], axis=0)
            # Thresholding: keep only pixels whose averaged score exceeds 0.4.
            y_pred = np.where(y_pred > 0.4, y_pred, 0)
            plot_prediction(y_pred)
            # TODO: offer y_pred (the averaged map, not a single-rotation
            # prediction) as a FITS download, e.g. write
            # CCDData(y_pred, unit="adu", wcs=wcs) to a BytesIO and hand it to
            # st.download_button(..., file_name="predicted.fits",
            # mime="application/octet-stream").
|