|
import streamlit as st
from streamlit.errors import StreamlitAPIException

import numpy as np

import matplotlib.pyplot as plt

from astropy.io import fits

from astropy.wcs import WCS

from astropy.nddata import Cutout2D

from tensorflow.keras.models import load_model
|
|
|
try:
    # Silences the warning Streamlit shows when st.pyplot() is called
    # without an explicit figure.
    st.set_option('deprecation.showPyplotGlobalUse', False)
except StreamlitAPIException:
    # Newer Streamlit releases removed this option and reject unknown keys;
    # there is nothing to silence there, so carry on.
    pass

st.title("FITS Image Viewer")

# Streamlit re-executes this script on every widget interaction. Cache the
# Keras model so it is loaded from disk once and reused across reruns.
# st.cache_resource requires Streamlit >= 1.18; on older versions fall back
# to loading the model on each run (the previous behaviour).
_cache_resource = getattr(st, "cache_resource", lambda func: func)


@_cache_resource
def _load_cadet_model(path="CADET.hdf5"):
    """Load and return the Keras model stored at *path*."""
    return load_model(path)


model = _load_cadet_model()
|
|
|
|
|
def plot_image(image_array, pred):
    """Render an image and the model prediction side by side in Streamlit.

    Parameters
    ----------
    image_array : array-like
        2-D image drawn in the left panel (``origin="lower"``).
    pred : array-like
        2-D prediction map drawn in the right panel (``origin="lower"``).
    """
    # Use an explicit figure instead of pyplot's global state; calling
    # st.pyplot() without a figure is deprecated by Streamlit.
    fig, (ax_image, ax_pred) = plt.subplots(1, 2, figsize=(10, 5))
    for ax, array in ((ax_image, image_array), (ax_pred, pred)):
        ax.imshow(array, origin="lower")
        ax.axis('off')
    st.pyplot(fig)
    # pyplot keeps figures alive until closed; the script reruns on every
    # interaction, so close it to avoid accumulating open figures.
    plt.close(fig)
|
|
|
def cut(data0, wcs0, scale=1):
    """Return a square cutout centred on the middle of a 2-D image.

    Parameters
    ----------
    data0 : numpy.ndarray
        2-D image array, indexed ``[row, column]``.
    wcs0 : astropy.wcs.WCS
        World coordinate system of *data0*; Cutout2D derives the cutout's WCS
        from it.
    scale : int or float, optional
        Multiplier on the 128-pixel base side length (default 1).

    Returns
    -------
    tuple
        ``(cutout_data, cutout_wcs)`` from :class:`astropy.nddata.Cutout2D`.
        With Cutout2D's default ``mode='trim'`` the returned array is smaller
        than requested when the image does not fully contain the box.
    """
    n_rows, n_cols = data0.shape
    # Cutout2D expects the centre as (x, y) = (column, row); deriving both
    # coordinates from shape[0] only works for square images.
    center = (n_cols / 2, n_rows / 2)
    size = 128 * scale
    cutout = Cutout2D(data0, center, (size, size), wcs=wcs0)
    return cutout.data, cutout.wcs
|
|
|
|
|
def _load_cutout(file_obj):
    """Read the first 2-D image in *file_obj* and return its central cutout.

    Returns
    -------
    tuple
        ``(data, wcs)`` where *data* is a finite float array of shape
        (128, 128) and *wcs* is the cutout's WCS.

    Raises
    ------
    ValueError
        If the file holds no 2-D image, the cutout is not 128x128, or
        Cutout2D rejects the position.
    OSError
        Presumably raised by ``fits.open`` for unreadable/corrupt files.
    """
    with fits.open(file_obj) as hdul:
        # Multi-extension files often leave the primary HDU empty
        # (data is None), so use the first HDU that carries a 2-D array.
        hdu = next(
            (h for h in hdul if h.data is not None and np.ndim(h.data) == 2),
            None,
        )
        if hdu is None:
            raise ValueError("no 2-D image found in the FITS file")
        data, wcs = cut(hdu.data, WCS(hdu.header), scale=1)
        # Copy into a standalone float array while the file is still open.
        data = np.array(data, dtype=float)
    if data.shape != (128, 128):
        # Cutout2D's default 'trim' mode shrinks the box when the image does
        # not fully contain it, but the model only accepts 128x128 inputs.
        raise ValueError(
            f"cutout is {data.shape[0]}x{data.shape[1]} pixels, expected 128x128"
        )
    # NaN/inf pixels would otherwise propagate through log10 into the model.
    return np.where(np.isfinite(data), data, 0.0), wcs


uploaded_file = st.file_uploader("Choose a FITS file", type=['fits'])


if uploaded_file is not None:
    try:
        data, wcs = _load_cutout(uploaded_file)
    except (OSError, ValueError) as err:
        # Report bad uploads in the UI instead of a raw traceback.
        st.error(f"Could not process the uploaded file: {err}")
    else:
        image_data = np.log10(data + 1)

        pred = model.predict(image_data.reshape(1, 128, 128, 1)).reshape(128, 128)

        plot_image(image_data, pred)
|
|