# CADET / app.py
# (Scraped repository-viewer header: author Plsek; commit 8120054 "Update app.py";
#  raw / history / blame links; 3.36 kB)
import matplotlib.pyplot as plt
import numpy as np
import streamlit as st
from astropy.io import fits
from astropy.nddata import Cutout2D, CCDData
from astropy.wcs import WCS
from matplotlib.patches import Rectangle
from streamlit.errors import StreamlitAPIException
from tensorflow.keras.models import load_model
# Hide Streamlit's deprecation warning for st.pyplot() calls that render the
# implicit global Matplotlib figure. Recent Streamlit releases removed this
# config option and st.set_option then raises StreamlitAPIException; ignore
# that so the app still starts on those versions.
try:
    st.set_option('deprecation.showPyplotGlobalUse', False)
except StreamlitAPIException:
    pass

st.title("Cavity Detection Tool")
# st.markdown (unlike st.text) wraps the paragraph and renders the citation link.
st.markdown(
    "Cavity Detection Tool (CADET) is a machine learning pipeline trained to detect "
    "X-ray cavities from noisy Chandra images of early-type galaxies. To use this tool "
    "upload your image, select the scale of interest and make a prediction! If you use "
    "output of this tool in your research please cite "
    "[Plšek et al. 2023](https://arxiv.org/abs/2304.05457)"
)

# Streamlit reruns this whole script on every widget interaction.
# st.cache_resource (Streamlit >= 1.18) keeps a single loaded model across those
# reruns; on older Streamlit the fallback is a no-op decorator, i.e. the model
# is reloaded on each run exactly as before.
_cache_resource = getattr(st, "cache_resource", lambda func: func)


@_cache_resource
def _load_cadet_model():
    """Load the trained CADET network from ``CADET.hdf5``."""
    return load_model("CADET.hdf5")


model = _load_cadet_model()

# Create file uploader widget
uploaded_file = st.file_uploader("Choose a FITS file", type=['fits'])
col1, col2 = st.columns(2)
# Define function to plot the uploaded image
def plot_image(image_array, scale):
    """Render the uploaded image in the left column with the analysed square outlined.

    Parameters
    ----------
    image_array : 2-D array
        Image to display (the app passes ``log10(data + 1)``).
    scale : int
        Scale chosen on the slider; the outlined square has a side of
        ``128 * scale`` pixels, matching the region that ``cut`` extracts.
    """
    size = scale * 128
    # Lower-left corner of a square centred on the image. Like ``cut``, this
    # uses shape[0] for both axes, i.e. it assumes a square image.
    x0 = image_array.shape[0] // 2 - size / 2

    fig, ax = plt.subplots(figsize=(4, 4))
    ax.imshow(image_array, origin="lower")
    ax.add_patch(Rectangle((x0, x0), size, size, linewidth=1, edgecolor='w', facecolor='none'))
    ax.axis('off')

    # Pass the figure explicitly (rendering pyplot's implicit global figure is
    # deprecated in Streamlit), then close it so figures do not accumulate in
    # pyplot's registry across script reruns.
    with col1:
        st.pyplot(fig)
    plt.close(fig)
# Define function to plot the prediction
def plot_prediction(pred):
    """Render the predicted cavity map in the right-hand column.

    Parameters
    ----------
    pred : 2-D array
        Prediction to display (the app passes the 128x128 rotation-averaged
        network output).
    """
    fig, ax = plt.subplots(figsize=(4, 4))
    ax.imshow(pred, origin="lower")
    ax.axis('off')

    # Pass the figure explicitly (rendering pyplot's implicit global figure is
    # deprecated in Streamlit), then close it so figures do not accumulate in
    # pyplot's registry across script reruns.
    with col2:
        st.pyplot(fig)
    plt.close(fig)
def cut(data0, wcs0, scale=1):
    """Cut the central ``128 * scale`` pixel square out of an image and rebin it to 128x128.

    Parameters
    ----------
    data0 : 2-D ndarray
        Full image. The cut is centred on ``data0.shape[0] / 2`` along both
        axes, so a square image is assumed (TODO confirm for non-square inputs).
    wcs0 : astropy.wcs.WCS
        WCS describing ``data0``.
    scale : int, optional
        Positive integer binning factor; 1 keeps the native resolution.

    Returns
    -------
    data : ndarray of shape (128, 128)
        Mean of each ``scale x scale`` block of the cutout.
    wcs : astropy.wcs.WCS
        WCS describing the returned, rebinned image.

    Raises
    ------
    ValueError
        If ``scale`` is not a positive integer (the block reshape needs one).
    """
    factor = int(scale)
    if factor != scale or factor < 1:
        raise ValueError(f"scale must be a positive integer, got {scale!r}")
    size = 128 * factor

    x0 = data0.shape[0] / 2
    # mode="partial" zero-pads wherever the requested square extends past the
    # image edge, so the cutout is always size x size and the reshape below
    # cannot fail for images smaller than 128 * scale pixels.
    cutout = Cutout2D(data0, (x0, x0), (size, size), wcs=wcs0, mode="partial", fill_value=0)
    data, wcs = cutout.data, cutout.wcs

    # REGRID DATA: average every factor x factor block into one output pixel.
    data = data.reshape(128, factor, 128, factor).mean(-1).mean(1)

    # REGRID WCS: each binned pixel spans `factor` original pixels, so the
    # pixel scale grows by `factor`. CDELT is ignored when a CD matrix is
    # present, so scale whichever one the header actually uses.
    if wcs.wcs.has_cd():
        wcs.wcs.cd = wcs.wcs.cd * factor
    else:
        wcs.wcs.cdelt[0] = wcs.wcs.cdelt[0] * factor
        wcs.wcs.cdelt[1] = wcs.wcs.cdelt[1] * factor
    # A 1-based cutout coordinate X lies at (X - 0.5) / factor + 0.5 on the
    # binned grid. Mapping CRPIX this way (keeping CRVAL) makes each binned
    # pixel centre land on the same sky position as the centre of the block of
    # pixels it averages.
    wcs.wcs.crpix[0] = (wcs.wcs.crpix[0] - 0.5) / factor + 0.5
    wcs.wcs.crpix[1] = (wcs.wcs.crpix[1] - 0.5) / factor + 0.5
    return data, wcs
# If file is uploaded, read in the data and plot it
if uploaded_file is not None:
    with fits.open(uploaded_file) as hdul:
        # Take the first image HDU that holds 2-D data: files that keep the
        # image in an extension have an empty primary HDU (data is None).
        image_hdu = next(
            (
                hdu
                for hdu in hdul
                if isinstance(hdu, (fits.PrimaryHDU, fits.ImageHDU, fits.CompImageHDU))
                and hdu.data is not None
                and hdu.data.ndim == 2
            ),
            None,
        )
        if image_hdu is None:
            st.error("The uploaded FITS file does not contain a 2-D image.")
            st.stop()
        data = image_hdu.data
        wcs = WCS(image_hdu.header)

    # Add a slider to change the scale
    scale = st.slider("Scale", 1, 4, 1, 1)
    plot_image(np.log10(data + 1), scale)

    with col2:
        if st.button('Detect cavities'):
            data, wcs = cut(data, wcs, scale=scale)
            image_data = np.log10(data + 1)

            # Test-time augmentation: predict on the four 90-degree rotations
            # in one predict call, rotate each prediction back, and average.
            rotated = np.stack([np.rot90(image_data, k) for k in range(4)])
            preds = model.predict(rotated.reshape(4, 128, 128, 1)).reshape(4, 128, 128)
            y_pred = np.mean([np.rot90(pred, -k) for k, pred in enumerate(preds)], axis=0)

            # TODO: FITS export of the prediction is disabled. If re-enabled it
            # must save the rotation-averaged `y_pred`, not a single rotation.
            # ccd = CCDData(y_pred, unit="adu", wcs=wcs)
            # ccd.write(f"predicted.fits", overwrite=True)
            plot_prediction(y_pred)
# if st.button('Download FITS File'):
#     with open('predicted.fits', 'rb') as f:
#         data = f.read()
#     st.download_button(label="Download", data=data, file_name="predicted.fits", mime="application/octet-stream")