|
import streamlit as st |
|
import numpy as np |
|
import matplotlib.pyplot as plt |
|
from matplotlib.patches import Rectangle |
|
from astropy.io import fits |
|
from astropy.wcs import WCS |
|
from astropy.nddata import Cutout2D, CCDData |
|
from tensorflow.keras.models import load_model |
|
|
|
# Turn off the 'deprecation.showPyplotGlobalUse' notice for the whole app.
st.set_option('deprecation.showPyplotGlobalUse', False)

st.title("Cavity Detection Tool")


def _load_cadet_model():
    """Load the Keras network stored in ``CADET.hdf5``."""
    return load_model("CADET.hdf5")


# Streamlit re-executes this script from the top on every widget interaction,
# so an uncached load_model() call would re-read the network from disk each
# time the slider moves or the button is pressed. Cache the loaded model
# instead; fall back to the legacy st.cache API when the installed Streamlit
# does not provide st.cache_resource.
if hasattr(st, "cache_resource"):
    _load_cadet_model = st.cache_resource(_load_cadet_model)
else:
    _load_cadet_model = st.cache(allow_output_mutation=True)(_load_cadet_model)

model = _load_cadet_model()
|
|
|
|
|
def plot_image(image_array, scale):
    """Render the image in Streamlit with the analysis box outlined.

    A white, unfilled square of side ``scale * 128`` pixels is drawn around
    the image centre to show the region that will be cut out.

    Args:
        image_array: 2-D array to display (drawn with ``origin="lower"``).
        scale: multiplier applied to the 128-pixel base box size.
    """
    box = scale * 128
    # Both corner coordinates are derived from shape[0], i.e. the image is
    # treated as square.
    corner = image_array.shape[0] // 2 - box / 2

    # Use an explicit figure instead of the global pyplot state, and always
    # close it: pyplot keeps every figure alive until plt.close(), so a
    # figure per rerun would otherwise accumulate in the server process.
    fig, ax = plt.subplots(figsize=(4, 4))
    try:
        ax.imshow(image_array, origin="lower")
        ax.add_patch(
            Rectangle(
                (corner, corner),
                box,
                box,
                linewidth=1,
                edgecolor='w',
                facecolor='none',
            )
        )
        ax.axis('off')
        st.pyplot(fig, width=200)
    finally:
        plt.close(fig)
|
|
|
|
|
def plot_prediction(image_array, pred):
    """Show the input image and the network prediction side by side.

    Args:
        image_array: 2-D array shown in the left panel.
        pred: 2-D array shown in the right panel.
    """
    # Explicit figure rather than the global pyplot state; closed in the
    # finally clause so repeated reruns do not pile up open figures.
    fig, (ax_image, ax_pred) = plt.subplots(1, 2, figsize=(8, 4))
    try:
        for ax, panel in ((ax_image, image_array), (ax_pred, pred)):
            ax.imshow(panel, origin="lower")
            ax.axis('off')
        st.pyplot(fig, width=400)
    finally:
        plt.close(fig)
|
|
|
def cut(data0, wcs0, scale=1):
    """Cut out the central ``128 * scale`` square and bin it to 128x128.

    The cutout is centred on ``data0.shape[0] / 2`` along both axes (the
    image is treated as square), then averaged in ``scale`` x ``scale``
    blocks so the result is always 128x128 pixels.

    Args:
        data0: 2-D image array.
        wcs0: WCS object describing ``data0``; passed to ``Cutout2D``.
        scale: size of the cutout in units of 128 pixels, which is also the
            binning factor.

    Returns:
        Tuple ``(data, wcs)``: the 128x128 block-averaged array and the WCS
        of the cutout adjusted for the binning.

    Raises:
        ValueError: if the image does not fully contain the requested
            ``128 * scale`` pixel square.
    """
    size = 128 * scale
    centre = data0.shape[0] / 2

    cutout = Cutout2D(data0, (centre, centre), (size, size), wcs=wcs0)
    data, wcs = cutout.data, cutout.wcs

    # Cutout2D trims the cutout at the image edges by default, so a small
    # image yields a smaller array; fail with a clear message rather than an
    # opaque reshape error further down.
    if data.shape != (size, size):
        raise ValueError(
            f"image is too small for a {size}x{size} pixel cutout "
            f"(got {data.shape[1]}x{data.shape[0]})"
        )

    # Average factor x factor pixel blocks down to a 128x128 image.
    factor = size // 128
    data = data.reshape(128, factor, 128, factor).mean(-1).mean(1)

    # Re-anchor the WCS on cutout pixel (63, 63) (0-based) and scale the
    # pixel size by the binning factor.
    ra, dec = wcs.wcs_pix2world(np.array([[63, 63]]), 0)[0]
    wcs.wcs.cdelt[0] = wcs.wcs.cdelt[0] * factor
    wcs.wcs.cdelt[1] = wcs.wcs.cdelt[1] * factor
    wcs.wcs.crval[0] = ra
    wcs.wcs.crval[1] = dec
    # 0-based pixel 63 is 1-based pixel 64 before binning. Block averaging
    # maps a 1-based position p to (p - 0.5) / factor + 0.5, so the reference
    # pixel lands at 63.5 / factor + 0.5 (equal to 64 when factor == 1).
    binned_ref = 63.5 / factor + 0.5
    wcs.wcs.crpix[0] = binned_ref
    wcs.wcs.crpix[1] = binned_ref

    return data, wcs
|
|
|
|
|
uploaded_file = st.file_uploader("Choose a FITS file", type=['fits'])

# Everything below needs an image, so it only runs once a file is uploaded.
if uploaded_file is not None:
    # Read the image and its WCS from the primary HDU.
    with fits.open(uploaded_file) as hdul:
        primary = hdul[0]
        data = primary.data
        wcs = WCS(primary.header)

    scale = st.slider("Scale", 1, 4, 1, 1)

    # Preview on a log10(data + 1) stretch with the selected box outlined.
    plot_image(np.log10(data + 1), scale)

    if st.button('Detect cavities'):
        data, wcs = cut(data, wcs, scale=scale)
        image_data = np.log10(data + 1)

        # Average the prediction over the four 90-degree rotations of the
        # input; each prediction is rotated back before being accumulated.
        y_pred = 0
        for quarter_turns in range(4):
            turned = np.rot90(image_data, quarter_turns)
            network_input = turned.reshape(1, 128, 128, 1)
            turned_pred = model.predict(network_input).reshape(128, 128)
            y_pred += np.rot90(turned_pred, -quarter_turns) / 4

        plot_prediction(image_data, y_pred)
|
|
|
|
|
|
|
|
|
|
|
|