# Cavity Detection Tool (CADET) -- Streamlit front-end.
#
# Lets the user upload a FITS image, choose a cutout scale and run the CADET
# network on the central region of the image. Streamlit re-executes this whole
# script on every widget interaction, so the model load is cached and every
# Matplotlib figure is closed after it has been rendered.

import streamlit as st
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from astropy.io import fits
from astropy.wcs import WCS
from astropy.nddata import Cutout2D, CCDData
from tensorflow.keras.models import load_model

# Side length (pixels) of the square image the network consumes and produces;
# matches the (1, 128, 128, 1) input / (128, 128) output reshapes used below.
IMAGE_SIZE = 128


@st.cache_resource
def _load_cadet_model(path="CADET.hdf5"):
    """Load the trained CADET Keras model once and reuse it across reruns."""
    return load_model(path)


st.title("Cavity Detection Tool")
# st.markdown (not st.text) so that the citation link is rendered as a link.
st.markdown(
    "Cavity Detection Tool (CADET) is a machine learning pipeline trained to "
    "detect X-ray cavities from noisy Chandra images of early-type galaxies. "
    "To use this tool, upload your image, select the scale of interest and "
    "make a prediction! If you use the output of this tool in your research, "
    "please cite [Plšek et al. 2023](https://arxiv.org/abs/2304.05457)."
)

model = _load_cadet_model()

# Create file uploader widget
uploaded_file = st.file_uploader("Choose a FITS file", type=['fits'])

col1, col2 = st.columns(2)


def plot_image(image_array, scale):
    """Show *image_array* in the left column with the cutout region outlined.

    Parameters
    ----------
    image_array : 2-D array
        Image to display (the caller passes ``log10(data + 1)``).
    scale : int
        Cutout scale; the outlined square is ``IMAGE_SIZE * scale`` pixels
        wide and centred on the image, i.e. the region ``cut`` extracts.
    """
    size = IMAGE_SIZE * scale
    ny, nx = image_array.shape
    x0 = nx // 2 - size / 2
    y0 = ny // 2 - size / 2

    fig, ax = plt.subplots(figsize=(4, 4))
    ax.imshow(image_array, origin="lower")
    ax.add_patch(Rectangle((x0, y0), size, size,
                           linewidth=1, edgecolor='w', facecolor='none'))
    ax.axis('off')
    with col1:
        st.pyplot(fig)
    # Close the figure so pyplot does not accumulate figures across reruns.
    plt.close(fig)


def plot_prediction(image_array, pred):
    """Show the cavity prediction map *pred* in the right column.

    ``image_array`` is accepted for interface compatibility but is not drawn.
    """
    fig, ax = plt.subplots(figsize=(4, 4))
    ax.imshow(pred, origin="lower")
    ax.axis('off')
    with col2:
        st.pyplot(fig)
    plt.close(fig)


def cut(data0, wcs0, scale=1):
    """Cut out the central region of an image and bin it to IMAGE_SIZE^2.

    Parameters
    ----------
    data0 : 2-D array
        Full image.
    wcs0 : astropy.wcs.WCS
        WCS of ``data0``.
    scale : int
        The cutout is ``IMAGE_SIZE * scale`` pixels wide and is
        block-averaged by ``scale`` down to ``IMAGE_SIZE`` x ``IMAGE_SIZE``.

    Returns
    -------
    data : ndarray, shape (IMAGE_SIZE, IMAGE_SIZE)
        Binned cutout.
    wcs : astropy.wcs.WCS
        WCS of the binned cutout (pixel size multiplied by ``scale``).

    Raises
    ------
    ValueError
        If the image is too small to contain the requested cutout.
    """
    size = IMAGE_SIZE * scale
    ny, nx = data0.shape
    # Cutout2D takes the position as (x, y), i.e. (column, row).
    cutout = Cutout2D(data0, (nx / 2, ny / 2), (size, size), wcs=wcs0)
    data, wcs = cutout.data, cutout.wcs
    if data.shape != (size, size):
        raise ValueError(
            f"image of shape {data0.shape} is too small for a "
            f"{size}x{size} cutout (scale={scale})")

    # REGRID DATA: average each factor x factor block into one pixel.
    factor = size // IMAGE_SIZE
    data = data.reshape(IMAGE_SIZE, factor, IMAGE_SIZE, factor).mean(axis=(1, 3))

    # REGRID WCS: pin the sky position of the cutout centre and enlarge the
    # pixel size by ``factor``.
    centre = (size - 1) / 2  # 0-based centre of the unbinned cutout
    ra, dec = wcs.wcs_pix2world(np.array([[centre, centre]]), 0)[0]
    if wcs.wcs.has_cd():
        # With a CD matrix, CDELT is ignored by WCSLIB; scale CD instead.
        wcs.wcs.cd = wcs.wcs.cd * factor
    else:
        wcs.wcs.cdelt[0] = wcs.wcs.cdelt[0] * factor
        wcs.wcs.cdelt[1] = wcs.wcs.cdelt[1] * factor
    wcs.wcs.crval[0] = ra
    wcs.wcs.crval[1] = dec
    # CRPIX is 1-based: the centre of an IMAGE_SIZE-wide image is at
    # (IMAGE_SIZE + 1) / 2, which maps back to ``centre`` above.
    wcs.wcs.crpix[0] = (IMAGE_SIZE + 1) / 2
    wcs.wcs.crpix[1] = (IMAGE_SIZE + 1) / 2
    return data, wcs


# If file is uploaded, read in the data and plot it
if uploaded_file is not None:
    with fits.open(uploaded_file) as hdul:
        hdu = hdul[0]
        # Native-endian float: FITS arrays are big-endian and may be small
        # integer types that would overflow in ``data + 1`` below.
        data = None if hdu.data is None else np.asarray(hdu.data, dtype=float)
        wcs = WCS(hdu.header)

    if data is None or data.ndim != 2:
        st.error("The primary HDU of the uploaded file does not contain a 2-D image.")
        st.stop()

    # Add a slider to change the scale
    scale = st.slider("Scale", 1, 4, 1, 1)
    if IMAGE_SIZE * scale > min(data.shape):
        st.error(
            f"The image ({data.shape[1]}x{data.shape[0]} pixels) is too small "
            f"for scale {scale}; choose a smaller scale."
        )
        st.stop()

    plot_image(np.log10(data + 1), scale)

    if st.button('Detect cavities'):
        data, wcs = cut(data, wcs, scale=scale)
        image_data = np.log10(data + 1)

        # Test-time augmentation: predict all four 90-degree rotations in one
        # batch, rotate each prediction back and average them.
        batch = np.stack([np.rot90(image_data, j) for j in range(4)])
        preds = model.predict(batch.reshape(4, IMAGE_SIZE, IMAGE_SIZE, 1))
        preds = preds.reshape(4, IMAGE_SIZE, IMAGE_SIZE)
        y_pred = np.mean([np.rot90(p, -j) for j, p in enumerate(preds)], axis=0)

        # TODO: offer the averaged prediction as a FITS download, e.g.
        # ccd = CCDData(y_pred, unit="adu", wcs=wcs)
        # ccd.write("predicted.fits", overwrite=True)
        # with open("predicted.fits", "rb") as f:
        #     st.download_button(label="Download", data=f.read(),
        #                        file_name="predicted.fits",
        #                        mime="application/octet-stream")

        plot_prediction(image_data, y_pred)