|
|
|
import numpy as np |
|
import matplotlib.pyplot as plt |
|
from matplotlib.patches import Rectangle |
|
|
|
|
|
from astropy.io import fits |
|
from astropy.wcs import WCS |
|
from astropy.nddata import Cutout2D, CCDData |
|
|
|
|
|
from tensorflow.keras.models import load_model |
|
# Load the CADET network from the HDF5 file once, at module level; the path is
# relative, so it is resolved against the process's working directory.
model = load_model(filepath="CADET.hdf5")
|
|
|
|
|
import streamlit as st |
|
# Silence the warning Streamlit emits when st.pyplot() is called without an
# explicit figure. NOTE(review): this option key may not be recognised by every
# Streamlit release; an unknown key must not abort the whole app, so that one
# specific failure is deliberately ignored.
try:
    st.set_option('deprecation.showPyplotGlobalUse', False)
except st.errors.StreamlitAPIException:
    pass
|
|
|
|
|
# Page title and short usage instructions.
st.title("Cavity Detection Tool")

st.markdown(
    "Cavity Detection Tool (CADET) is a machine learning pipeline trained to "
    "detect X-ray cavities from noisy Chandra images of early-type galaxies. "
    "To use this tool upload your image, select the scale of interest and "
    "make a prediction! If you use output of this tool in your research "
    "please cite [Plšek et al. 2023](https://arxiv.org/abs/2304.05457)"
)

# Upload widget restricted to the 'fits' extension; the rest of the script
# only runs once this is no longer None.
uploaded_file = st.file_uploader("Choose a FITS file", type=['fits'])
|
|
|
|
|
def plot_image(image_array, scale):
    """Display an image with the square region selected by ``scale`` outlined.

    Args:
        image_array: 2D array to show (drawn with the origin at the bottom).
        scale: Multiple of 128 pixels giving the side length of the outlined
            square, which is centred on the image.
    """
    side = scale * 128
    n_rows, n_cols = image_array.shape[0], image_array.shape[1]

    # The rectangle anchor is (x, y) = (column, row), so each coordinate is
    # derived from its own axis; this keeps the box centred on non-square
    # images as well.
    x0 = n_cols // 2 - side / 2
    y0 = n_rows // 2 - side / 2

    fig, ax = plt.subplots(figsize=(4, 4))
    ax.imshow(image_array, origin="lower")
    ax.add_patch(
        Rectangle((x0, y0), side, side, linewidth=1, edgecolor='w', facecolor='none')
    )
    ax.axis('off')

    # Pass the figure explicitly instead of relying on pyplot's global
    # "current figure", then close it so figures do not pile up across reruns.
    st.pyplot(fig)
    plt.close(fig)
|
|
|
|
|
def plot_prediction(pred):
    """Display a prediction map as an axis-free image.

    Args:
        pred: 2D array to show (drawn with the origin at the bottom).
    """
    fig, ax = plt.subplots(figsize=(4, 4))
    ax.imshow(pred, origin="lower")
    ax.axis('off')

    # Pass the figure explicitly instead of relying on pyplot's global
    # "current figure", then close it so figures do not pile up across reruns.
    st.pyplot(fig)
    plt.close(fig)
|
|
|
|
|
def cut(data0, wcs0, scale=1):
    """Cut out the central square of an image and rebin it to 128x128 pixels.

    A square of side ``128 * scale`` pixels centred on the image is extracted
    and averaged in ``scale`` x ``scale`` blocks. The WCS of the cutout is then
    updated to describe the rebinned pixel grid.

    Args:
        data0: 2D image array.
        wcs0: WCS object describing ``data0``, or None when no WCS is
            available.
        scale: Positive integer binning factor (1 means no rebinning).

    Returns:
        Tuple ``(data, wcs)`` with the 128x128 rebinned array and the adjusted
        WCS (None when ``wcs0`` is None).

    Raises:
        ValueError: If the centred ``128 * scale`` square does not fit inside
            the image.
    """
    base = 128
    size = base * scale
    factor = size // base

    # Cutout2D takes the position as (x, y) = (column, row), so each
    # coordinate comes from its own axis; using the row count for both would
    # mis-centre the cutout on non-square images.
    n_rows, n_cols = data0.shape[0], data0.shape[1]
    cutout = Cutout2D(data0, (n_cols / 2, n_rows / 2), (size, size), wcs=wcs0)
    data, wcs = cutout.data, cutout.wcs

    # A cutout that extends past the image edge comes back smaller than
    # requested; fail with a clear message instead of a reshape error.
    if data.shape != (size, size):
        raise ValueError(
            f"image of shape {tuple(data0.shape)} is too small for a "
            f"{size}x{size} pixel cutout (scale={scale})"
        )

    # Average factor x factor blocks of pixels down to a base x base image.
    data = data.reshape(base, factor, base, factor).mean(-1).mean(1)

    if wcs is not None:
        # Use the sky position of un-binned cutout pixel (63, 63) (0-based)
        # as the new reference value.
        ref_pix = base // 2 - 1
        ra, dec = wcs.wcs_pix2world(np.array([[ref_pix, ref_pix]]), 0)[0]

        # The pixel scale grows by the binning factor. When the header stores
        # the linear transform as a CD matrix, CDELT is ignored, so the matrix
        # itself has to be scaled instead.
        if wcs.wcs.has_cd():
            wcs.wcs.cd = wcs.wcs.cd * factor
        else:
            wcs.wcs.cdelt[0] = wcs.wcs.cdelt[0] * factor
            wcs.wcs.cdelt[1] = wcs.wcs.cdelt[1] * factor

        wcs.wcs.crval[0] = ra
        wcs.wcs.crval[1] = dec

        # 1-based (FITS) location of the reference pixel's centre on the
        # binned grid: q = (p - 0.5) / factor + 0.5 with p = ref_pix + 1.
        # This equals 64 for factor 1 and stays exact for larger factors.
        crpix = (ref_pix + 0.5) / factor + 0.5
        wcs.wcs.crpix[0] = crpix
        wcs.wcs.crpix[1] = crpix

    return data, wcs
|
|
|
|
|
|
|
if uploaded_file is not None:
    # Pull the primary HDU's pixel data and WCS out of the uploaded FITS file.
    with fits.open(uploaded_file) as hdul:
        data = hdul[0].data
        wcs = WCS(hdul[0].header)

    # The rest of the script indexes a 2D array; stop with a readable message
    # when the primary HDU is empty or has a different dimensionality.
    if data is None or data.ndim != 2:
        st.error("The primary HDU of the uploaded file does not contain a 2D image.")
        st.stop()

    # Largest multiple of 128 px that fits along BOTH axes; with no valid
    # option the scale selector would be empty, so bail out early.
    max_scale = int(min(data.shape) // 128)
    if max_scale < 1:
        st.error("The uploaded image must be at least 128x128 pixels.")
        st.stop()

    col1, col2 = st.columns(2)
    col1.subheader("Input image")
    col2.subheader("CADET prediction")

    with col1:
        # CSS tweak that shifts the select box upwards by 50 px.
        st.markdown(
            """<style>[data-baseweb="select"] {margin-top: -50px;}</style>""",
            unsafe_allow_html=True
        )

        scale = int(
            st.selectbox(
                'Scale:',
                list(range(1, max_scale + 1)),
                label_visibility="hidden",
            )
        )

        # Show the log-scaled image with the selected region outlined.
        plot_image(np.log10(data+1), scale)

    with col2:
        if st.button('Detect cavities'):
            data, wcs = cut(data, wcs, scale=scale)

            image = np.log10(data+1)

            # Average the network output over the four 90-degree rotations of
            # the input, rotating each prediction back before accumulating.
            y_pred = 0
            for j in range(4):
                rotated = np.rot90(image, j)
                pred = model.predict(rotated.reshape(1, 128, 128, 1)).reshape(128, 128)
                pred = np.rot90(pred, -j)
                y_pred += pred / 4

            # Zero out pixels whose averaged prediction is not above 0.4.
            y_pred = np.where(y_pred > 0.4, y_pred, 0)

            plot_prediction(y_pred)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|