"""Streamlit app that runs the CADET Keras model on an uploaded FITS image.

A square of (128 * scale) pixels around the image centre is cut out, binned
down to the 128x128 input the model takes, and the model output is averaged
over the four 90-degree rotations of that input.
"""
import streamlit as st
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from astropy.io import fits
from astropy.wcs import WCS
from astropy.nddata import Cutout2D, CCDData
from tensorflow.keras.models import load_model

# Side length (pixels) of the square image the model is fed
# (the original code reshapes inputs to (1, 128, 128, 1)).
MODEL_SIZE = 128


@st.cache_resource
def _load_model(path="CADET.hdf5"):
    """Load the Keras model once and reuse it across Streamlit reruns."""
    return load_model(path)


def plot_image(image_array, scale):
    """Show the full image with a white box around the region `cut` extracts.

    image_array: 2-D image (the caller passes it already log-scaled).
    scale: cutout side length in units of MODEL_SIZE pixels.
    """
    size = scale * MODEL_SIZE
    # Box of side `size` centred on the image centre, the same centre cut() uses.
    y0 = image_array.shape[0] / 2 - size / 2
    x0 = image_array.shape[1] / 2 - size / 2

    fig, ax = plt.subplots(figsize=(4, 4))
    ax.imshow(image_array, origin="lower")
    ax.add_patch(
        Rectangle((x0, y0), size, size, linewidth=1, edgecolor="w", facecolor="none")
    )
    ax.axis("off")
    # Pass the figure explicitly: global-figure st.pyplot() is deprecated.
    st.pyplot(fig)
    # Close it so repeated reruns do not accumulate open figures.
    plt.close(fig)


def plot_prediction(image_array, pred):
    """Show the model input (left) and the averaged prediction (right)."""
    fig, axes = plt.subplots(1, 2, figsize=(8, 4))
    for ax, panel in zip(axes, (image_array, pred)):
        ax.imshow(panel, origin="lower")
        ax.axis("off")
    st.pyplot(fig)
    plt.close(fig)


def cut(data0, wcs0, scale=1):
    """Cut a centred (MODEL_SIZE * scale)-pixel square and bin it to MODEL_SIZE.

    data0: 2-D image array.
    wcs0: WCS describing data0.
    scale: positive integer binning factor.

    Returns (data, wcs): the MODEL_SIZE x MODEL_SIZE block-averaged cutout and
    a WCS adjusted to the binned pixel grid.
    """
    factor = int(scale)
    size = MODEL_SIZE * factor
    # Cutout2D expects the position as (x, y) = (column, row).
    centre = (data0.shape[1] / 2, data0.shape[0] / 2)
    # mode="partial" zero-pads when the image is smaller than the requested
    # cutout, so the reshape below always receives a (size, size) array.
    cutout = Cutout2D(
        data0, centre, (size, size), wcs=wcs0, mode="partial", fill_value=0
    )
    data, wcs = cutout.data, cutout.wcs

    # REGRID DATA: mean over non-overlapping factor x factor blocks.
    data = data.reshape(MODEL_SIZE, factor, MODEL_SIZE, factor).mean(axis=(1, 3))

    # REGRID WCS: binned pixel Q (1-based) spans original pixel coordinates
    # f*(Q-1)+0.5 .. f*Q+0.5, i.e. P = f*(Q - 0.5) + 0.5. Hence the reference
    # pixel maps to (crpix - 0.5)/f + 0.5 and the pixel scale grows by f;
    # CRVAL stays the same. SIP distortion terms, if any, are not rescaled.
    wcs.wcs.crpix = (wcs.wcs.crpix - 0.5) / factor + 0.5
    if wcs.wcs.has_cd():
        wcs.wcs.cd = wcs.wcs.cd * factor
    else:
        wcs.wcs.cdelt = wcs.wcs.cdelt * factor
    wcs.array_shape = data.shape
    return data, wcs


def _predict_rotation_averaged(model, image):
    """Predict on the four 90-degree rotations of `image` and average them.

    Each prediction is rotated back to the original orientation before the
    average; all four rotations go through the model as a single batch.
    """
    batch = np.stack([np.rot90(image, k) for k in range(4)])[..., np.newaxis]
    preds = model.predict(batch).reshape(4, MODEL_SIZE, MODEL_SIZE)
    return np.mean([np.rot90(p, -k) for k, p in enumerate(preds)], axis=0)


st.title("Cavity Detection Tool")

model = _load_model()

# Create file uploader widget
uploaded_file = st.file_uploader("Choose a FITS file", type=['fits'])

# If file is uploaded, read in the data and plot it
if uploaded_file is not None:
    with fits.open(uploaded_file) as hdul:
        data = hdul[0].data
        wcs = WCS(hdul[0].header)

    if data is None or data.ndim != 2:
        st.error("The primary HDU of the FITS file must contain a 2-D image.")
        st.stop()

    # Slider selecting the cutout size in units of MODEL_SIZE pixels.
    scale = st.slider("Scale", 1, 4, 1, 1)

    plot_image(np.log10(data + 1), scale)

    if st.button('Detect cavities'):
        data, wcs = cut(data, wcs, scale=scale)
        image_data = np.log10(data + 1)
        y_pred = _predict_rotation_averaged(model, image_data)
        plot_prediction(image_data, y_pred)
        # TODO: offer y_pred as a FITS download, e.g. st.download_button with
        # fits.PrimaryHDU(y_pred, header=wcs.to_header()) written to an
        # in-memory buffer (an earlier draft wrote `pred`, the last single
        # rotation, instead of the averaged y_pred).