# CADET / app.py
# Plsek's picture
# Update app.py
# 8e1fb51
# raw
# history blame
# 3.91 kB
# Basic libraries
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
# Astropy
from astropy.io import fits
from astropy.wcs import WCS
from astropy.nddata import Cutout2D, CCDData
# Tensorflow
from tensorflow.keras.models import load_model
# Streamlit
import streamlit as st

# Streamlit re-executes this whole script on every widget interaction, so the
# Keras model is loaded through a resource cache instead of on every rerun.
# `st.cache_resource` needs Streamlit >= 1.18; older releases provide the
# equivalent `st.experimental_singleton`.
_cache_resource = getattr(st, "cache_resource", None) or st.experimental_singleton


@_cache_resource
def _load_cadet_model():
    """Load the trained CADET network from ``CADET.hdf5`` (cached across reruns)."""
    return load_model("CADET.hdf5")


model = _load_cadet_model()

# Recent Streamlit releases removed this config option and raise
# StreamlitAPIException when it is set; it only silenced the warning for
# calling st.pyplot() without a figure.
try:
    from streamlit.errors import StreamlitAPIException

    st.set_option('deprecation.showPyplotGlobalUse', False)
except StreamlitAPIException:
    pass

st.title("Cavity Detection Tool")
st.markdown("Cavity Detection Tool (CADET) is a machine learning pipeline trained to detect X-ray cavities from noisy Chandra images of early-type galaxies. To use this tool upload your image, select the scale of interest and make a prediction! If you use the output of this tool in your research please cite [Plšek et al. 2023](https://arxiv.org/abs/2304.05457)")

# Create file uploader widget; accept the common FITS file extensions.
uploaded_file = st.file_uploader("Choose a FITS file", type=['fits', 'fit', 'fts'])
def plot_image(image_array, scale):
    """Display *image_array* in Streamlit and outline the region to be analysed.

    A white square of side ``scale * 128`` pixels is drawn with its lower-left
    corner at ``image_array.shape[0] // 2 - scale * 64`` on both axes, i.e. it
    is centred on the middle of the image (a square image is assumed).

    Parameters
    ----------
    image_array : 2-D array
        Image to display (shown with ``origin="lower"``).
    scale : int
        Side of the outlined square in units of 128 pixels.
    """
    size = scale * 128
    # Same offset for x and y: the image is assumed to be square.
    x0 = image_array.shape[0] // 2 - size / 2

    fig, ax = plt.subplots(figsize=(4, 4))
    ax.imshow(image_array, origin="lower")
    ax.add_patch(Rectangle((x0, x0), size, size, linewidth=1, edgecolor='w', facecolor='none'))
    ax.axis('off')

    # Pass the figure explicitly: st.pyplot() without one relies on pyplot's
    # global figure, which Streamlit has deprecated.
    st.pyplot(fig)
    # pyplot keeps every figure alive until it is closed; close it so repeated
    # Streamlit reruns do not accumulate open figures.
    plt.close(fig)
def plot_prediction(pred):
    """Display the CADET prediction map *pred* in Streamlit.

    Parameters
    ----------
    pred : 2-D array
        Prediction map to show (drawn with ``origin="lower"``, axes hidden).
    """
    fig, ax = plt.subplots(figsize=(4, 4))
    ax.imshow(pred, origin="lower")
    ax.axis('off')

    # Pass the figure explicitly instead of relying on pyplot's deprecated
    # global figure, then close it so reruns do not leak open figures.
    st.pyplot(fig)
    plt.close(fig)
def cut(data0, wcs0, scale=1):
    """Cut a central ``128 * scale`` square from an image and rebin it to 128x128.

    Parameters
    ----------
    data0 : 2-D array
        Input image. The cutout is centred on ``data0.shape[0] / 2`` along
        both axes, so a square image is assumed.
    wcs0 : astropy.wcs.WCS
        WCS of *data0*.
    scale : int, optional
        Side of the cutout in units of 128 pixels. Each output pixel is the
        mean of a ``scale x scale`` block of input pixels.

    Returns
    -------
    data : ndarray, shape (128, 128)
        The rebinned cutout.
    wcs : astropy.wcs.WCS
        WCS describing *data* (a copy; *wcs0* is not modified).
    """
    x0 = data0.shape[0] / 2
    size = 128 * scale
    # Cutout2D returns a copy of the WCS with CRPIX shifted to the cutout frame.
    cutout = Cutout2D(data0, (x0, x0), (size, size), wcs=wcs0)
    data, wcs = cutout.data, cutout.wcs

    # Regrid data: average non-overlapping factor x factor blocks.
    factor = size // 128
    data = data.reshape(128, factor, 128, factor).mean(axis=(1, 3))

    # Regrid WCS. Output pixel q (0-based) covers input pixels
    # q*factor .. q*factor + factor - 1, so input pixel coordinate p maps to
    # (p + 0.5) / factor - 0.5; in 1-based FITS terms
    # CRPIX' = (CRPIX - 0.5) / factor + 0.5. CRVAL is left unchanged, which
    # keeps the sky projection itself untouched.
    # NOTE: distortion terms (e.g. SIP), if present, are not rescaled.
    wcs.wcs.crpix = (wcs.wcs.crpix - 0.5) / factor + 0.5
    if wcs.wcs.has_cd():
        # WCSLIB ignores CDELT when a CD matrix is present.
        wcs.wcs.cd = wcs.wcs.cd * factor
    else:
        wcs.wcs.cdelt = wcs.wcs.cdelt * factor
    return data, wcs
def _load_image(fits_file):
    """Return ``(data, wcs)`` for the first HDU of *fits_file* holding 2-D data.

    The pixel values are copied into a float64 array while the file is still
    open (this also avoids integer overflow in later ``data + 1``).
    Returns ``(None, None)`` when no HDU contains 2-D data.
    """
    with fits.open(fits_file) as hdul:
        for hdu in hdul:
            if hdu.data is not None and hdu.data.ndim == 2:
                return np.array(hdu.data, dtype=float), WCS(hdu.header)
    return None, None


def _predict_cavities(image, threshold=0.4):
    """Return the thresholded CADET prediction for a 128x128 *image*.

    The network is run once on a batch of the four 90-degree rotations of
    *image*. Each output is rotated back and the four are averaged. Values
    not above *threshold* are set to zero.
    """
    batch = np.stack([np.rot90(image, k) for k in range(4)])
    preds = model.predict(batch.reshape(4, 128, 128, 1)).reshape(4, 128, 128)
    y_pred = np.mean([np.rot90(pred, -k) for k, pred in enumerate(preds)], axis=0)
    return np.where(y_pred > threshold, y_pred, 0)


# If a file is uploaded, read in the data and plot it
if uploaded_file is not None:
    data, wcs = _load_image(uploaded_file)
    if data is None:
        st.error("No 2-D image was found in the uploaded FITS file.")
        st.stop()

    # Largest scale for which a 128*scale cutout fits along axis 0
    # (cut() and plot_image() both centre on data.shape[0] / 2).
    max_scale = data.shape[0] // 128
    if max_scale < 1:
        st.error("The image must be at least 128 pixels across.")
        st.stop()

    # Make two columns
    col1, col2 = st.columns(2)
    col1.subheader("Input image")
    col2.subheader("CADET prediction")

    with col1:
        # Shift select widgets up by 50px; the selectbox label is hidden.
        st.markdown(
            """<style>[data-baseweb="select"] {margin-top: -50px;}</style>""",
            unsafe_allow_html=True
        )
        scale = int(st.selectbox('Scale:', list(range(1, max_scale + 1)), label_visibility="hidden"))
        plot_image(np.log10(data + 1), scale)

    with col2:
        if st.button('Detect cavities'):
            cut_data, _ = cut(data, wcs, scale=scale)
            y_pred = _predict_cavities(np.log10(cut_data + 1))
            plot_prediction(y_pred)
            # TODO: offer y_pred as a FITS download (e.g. CCDData with the WCS
            # returned by cut(), served through st.download_button).