File size: 3,116 Bytes
6414f94 2680aa6 6414f94 3190548 6414f94 d26c581 6414f94 3190548 707774a ddce24a 3190548 eba1cee 3190548 2680aa6 3190548 eba1cee 3190548 707774a ddce24a 6414f94 eba1cee 6414f94 d26c581 44cb863 6414f94 2680aa6 3190548 eba1cee 3190548 eba1cee 2e34964 eba1cee 6414f94 2d330f9 d26c581 eba1cee 8e671ec 2d330f9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 |
import matplotlib.pyplot as plt
import numpy as np
import streamlit as st
from astropy.io import fits
from astropy.nddata import CCDData, Cutout2D
from astropy.wcs import WCS
from matplotlib.patches import Rectangle
from streamlit.errors import StreamlitAPIException
from tensorflow.keras.models import load_model
try:
    # On older Streamlit releases this hides the warning about calling
    # st.pyplot() without a figure. Newer releases removed the option and
    # raise StreamlitAPIException for it, so treat that as "nothing to do".
    st.set_option('deprecation.showPyplotGlobalUse', False)
except StreamlitAPIException:
    pass

st.title("Cavity Detection Tool")


def _load_cadet_model():
    """Load the CADET Keras model from ``CADET.hdf5`` (path relative to the CWD)."""
    return load_model("CADET.hdf5")


# Streamlit re-executes this whole script on every widget interaction; when
# st.cache_resource is available, cache the loaded model across reruns
# instead of reloading it from disk each time.
_cache_resource = getattr(st, "cache_resource", None)
if _cache_resource is not None:
    _load_cadet_model = _cache_resource(_load_cadet_model)
model = _load_cadet_model()
# Define function to plot the uploaded image
def plot_image(image_array, scale):
# st.set_plot_config(plt, figsize=(4, 4))
plt.figure(figsize=(4, 4))
# plt.subplot(1, 2, 1)
x0 = image_array.shape[0] // 2 - scale * 128 / 2
plt.imshow(image_array, origin="lower")
plt.gca().add_patch(Rectangle((x0, x0), scale*128, scale*128, linewidth=1, edgecolor='w', facecolor='none'))
plt.axis('off')
st.pyplot(width=200)
# Define function to plot the prediction
def plot_prediction(image_array, pred):
# st.set_plot_config(plt, figsize=(8, 4))
plt.figure(figsize=(8, 4))
plt.subplot(1, 2, 1)
plt.imshow(image_array, origin="lower")
plt.axis('off')
plt.subplot(1, 2, 2)
plt.imshow(pred, origin="lower")
plt.axis('off')
st.pyplot(width=400)
def cut(data0, wcs0, scale=1):
shape = data0.shape[0]
x0 = shape / 2
size = 128 * scale
cutout = Cutout2D(data0, (x0, x0), (size, size), wcs=wcs0)
data, wcs = cutout.data, cutout.wcs
# REGRID DATA
factor = size // 128
data = data.reshape(128, factor, 128, factor).mean(-1).mean(1)
# REGIRD WCS
ra, dec = wcs.wcs_pix2world(np.array([[63, 63]]),0)[0]
wcs.wcs.cdelt[0] = wcs.wcs.cdelt[0] * factor
wcs.wcs.cdelt[1] = wcs.wcs.cdelt[1] * factor
wcs.wcs.crval[0] = ra
wcs.wcs.crval[1] = dec
wcs.wcs.crpix[0] = 64 / factor
wcs.wcs.crpix[1] = 64 / factor
return data, wcs
# Create file uploader widget
uploaded_file = st.file_uploader("Choose a FITS file", type=['fits'])
# If file is uploaded, read in the data and plot it
if uploaded_file is not None:
with fits.open(uploaded_file) as hdul:
data = hdul[0].data
wcs = WCS(hdul[0].header)
# Add a slider to change the scale
scale = st.slider("Scale", 1, 4, 1, 1)
plot_image(np.log10(data+1), scale)
if st.button('Detect cavities'):
data, wcs = cut(data, wcs, scale=scale)
image_data = np.log10(data+1)
y_pred = 0
for j in [0,1,2,3]:
rotated = np.rot90(image_data, j)
pred = model.predict(rotated.reshape(1, 128, 128, 1)).reshape(128 ,128)
pred = np.rot90(pred, -j)
y_pred += pred / 4
# ccd = CCDData(pred, unit="adu", wcs=wcs)
# ccd.write(f"predicted.fits", overwrite=True)
plot_prediction(image_data, y_pred)
# if st.button('Download FITS File'):
# with open('predicted.fits', 'rb') as f:
# data = f.read()
# st.download_button(label="Download", data=data, file_name="predicted.fits", mime="application/octet-stream")
|