import io

import streamlit as st
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from astropy.io import fits
from astropy.wcs import WCS
from astropy.nddata import Cutout2D, CCDData
from tensorflow.keras.models import load_model
# NOTE(review): this config option only exists in older Streamlit releases;
# it must be removed when upgrading to a version that no longer knows it.
st.set_option('deprecation.showPyplotGlobalUse', False)
st.title("Cavity Detection Tool")


@st.cache_resource
def _load_cadet_model(path):
    """Load the Keras model stored at *path*, cached per server process.

    Streamlit re-executes the whole script on every widget interaction, so an
    uncached ``load_model`` call would re-read the model file every time the
    slider moves or a button is pressed.
    """
    return load_model(path)


model = _load_cadet_model("CADET.hdf5")
def plot_image(image_array, scale):
    """Show *image_array* with a white box marking the centred crop region.

    The box is ``128 * scale`` pixels on a side and centred on the image,
    i.e. the region that ``cut(..., scale=scale)`` extracts.

    Parameters
    ----------
    image_array : 2-D array
        Image to display (rows are y, columns are x).
    scale : int
        Crop size in units of 128 pixels.
    """
    size = scale * 128
    ny, nx = image_array.shape[0], image_array.shape[1]
    # imshow puts columns on the x axis and rows on the y axis, so each
    # corner coordinate must come from its own axis (matters for non-square images).
    x0 = nx // 2 - size / 2
    y0 = ny // 2 - size / 2

    # Use an explicit figure instead of pyplot's global state, and close it
    # afterwards so figures do not accumulate across Streamlit reruns.
    fig, ax = plt.subplots(figsize=(4, 4))
    ax.imshow(image_array, origin="lower")
    ax.add_patch(Rectangle((x0, y0), size, size, linewidth=1, edgecolor='w', facecolor='none'))
    ax.axis('off')

    # st.pyplot forwards extra kwargs to savefig rather than using them for
    # layout, so render to PNG ourselves and let st.image apply the display width.
    buffer = io.BytesIO()
    fig.savefig(buffer, format="png", bbox_inches="tight", dpi=200)
    plt.close(fig)
    st.image(buffer, width=200)
def plot_prediction(image_array, pred):
    """Show the (log-scaled) input image and the model prediction side by side.

    Parameters
    ----------
    image_array : 2-D array
        Image that was fed to the model.
    pred : 2-D array
        Prediction map to display next to it.
    """
    # Explicit figure/axes instead of pyplot's global state; the figure is
    # closed after rendering so it does not leak across Streamlit reruns.
    fig, (ax_image, ax_pred) = plt.subplots(1, 2, figsize=(8, 4))
    ax_image.imshow(image_array, origin="lower")
    ax_image.axis('off')
    ax_pred.imshow(pred, origin="lower")
    ax_pred.axis('off')

    # Render to PNG and display with st.image so the requested width is
    # honoured (st.pyplot would pass `width` on to savefig instead).
    buffer = io.BytesIO()
    fig.savefig(buffer, format="png", bbox_inches="tight", dpi=200)
    plt.close(fig)
    st.image(buffer, width=400)
def cut(data0, wcs0, scale=1):
    """Crop a centred ``128*scale`` square from *data0* and mean-bin it to 128x128.

    Parameters
    ----------
    data0 : 2-D array
        Full image (rows are y, columns are x).
    wcs0 : astropy.wcs.WCS
        WCS of *data0*.
    scale : int
        Crop size in units of 128 pixels; also the binning factor.

    Returns
    -------
    data : ndarray, shape (128, 128)
        The cropped image, each output pixel the mean of a ``scale x scale`` block.
    wcs : astropy.wcs.WCS
        WCS describing the binned pixel grid.

    Raises
    ------
    ValueError
        If the image is too small to contain the requested cutout.
    """
    factor = int(scale)
    size = 128 * factor
    ny, nx = data0.shape[0], data0.shape[1]

    # Cutout2D takes the position as (x, y): x indexes columns, y indexes rows.
    cutout = Cutout2D(data0, (nx / 2, ny / 2), (size, size), wcs=wcs0)
    data, wcs = cutout.data, cutout.wcs

    # Cutout2D's default "trim" mode silently shrinks a cutout that runs off
    # the image edge; the fixed-size reshape below would then fail obscurely.
    if data.shape != (size, size):
        raise ValueError(
            f"Image of shape {data0.shape} is too small for a {size}x{size} "
            f"cutout (scale={scale}); choose a smaller scale."
        )

    # Mean-bin factor x factor blocks down to 128 x 128.
    data = data.reshape(128, factor, 128, factor).mean(axis=(1, 3))

    # Exact WCS for the binned grid: in 1-based FITS pixel coordinates, binned
    # pixel p' spans original pixels f*(p'-1)+0.5 .. f*p'+0.5, so an original
    # coordinate x maps to (x - 0.5)/f + 0.5. CRVAL stays unchanged and the
    # pixel scale grows by f.
    # NOTE(review): SIP/distortion terms, if present, are not rescaled here.
    wcs.wcs.crpix = (wcs.wcs.crpix - 0.5) / factor + 0.5
    if wcs.wcs.has_cd():
        # CDELT is ignored when a CD matrix is present, so scale the matrix.
        wcs.wcs.cd = wcs.wcs.cd * factor
    else:
        wcs.wcs.cdelt = wcs.wcs.cdelt * factor
    return data, wcs
uploaded_file = st.file_uploader("Choose a FITS file", type=['fits'])

if uploaded_file is not None:
    with fits.open(uploaded_file) as hdul:
        data = hdul[0].data
        wcs = WCS(hdul[0].header)

    # An empty primary HDU would otherwise crash at np.log10(None + 1).
    if data is None:
        st.error("The primary HDU of this FITS file contains no image data.")
        st.stop()

    scale = st.slider("Scale", 1, 4, 1, 1)
    plot_image(np.log10(data + 1), scale)

    if st.button('Detect cavities'):
        data, wcs = cut(data, wcs, scale=scale)
        image_data = np.log10(data + 1)

        # Test-time augmentation: predict on the four 90-degree rotations in a
        # single batch, rotate each prediction back, and average them.
        rotations = np.stack([np.rot90(image_data, j) for j in range(4)])
        preds = model.predict(rotations.reshape(4, 128, 128, 1)).reshape(4, 128, 128)
        y_pred = np.mean([np.rot90(p, -j) for j, p in enumerate(preds)], axis=0)

        plot_prediction(image_data, y_pred)

        # Export the *averaged* prediction (not just the last rotation) and
        # serve it directly from memory. A second st.button gating the
        # download would trigger a rerun in which the 'Detect cavities'
        # branch is skipped, and a shared on-disk file would be overwritten
        # by concurrent sessions.
        fits_buffer = io.BytesIO()
        CCDData(y_pred, unit="adu", wcs=wcs).to_hdu().writeto(fits_buffer)
        st.download_button(
            label="Download FITS File",
            data=fits_buffer.getvalue(),
            file_name="predicted.fits",
            mime="application/octet-stream",
        )