|
|
|
from huggingface_hub import from_pretrained_keras |
|
# Hugging Face Hub repository holding the trained CADET network.
CADET_REPO_ID = "Plsek/CADET-v1"

# NOTE(review): Streamlit re-executes this script on every widget interaction,
# so this line reloads the model on each rerun; consider caching it
# (e.g. st.cache_resource) once the supported Streamlit version is confirmed.
model = from_pretrained_keras(CADET_REPO_ID)
|
|
|
|
|
import numpy as np |
|
import matplotlib.pyplot as plt |
|
from matplotlib.patches import Rectangle |
|
|
|
|
|
from astropy.io import fits |
|
from astropy.wcs import WCS |
|
from astropy.nddata import Cutout2D, CCDData |
|
from astropy.convolution import Gaussian2DKernel as Gauss |
|
from astropy.convolution import convolve |
|
|
|
|
|
import streamlit as st |
|
from streamlit.errors import StreamlitAPIException

# Hide Streamlit's deprecation warning for st.pyplot() calls made without an
# explicit figure. Newer Streamlit versions may no longer recognise this
# option and raise for unknown options, so setting it is best-effort only.
try:
    st.set_option('deprecation.showPyplotGlobalUse', False)
except StreamlitAPIException:
    pass

st.set_page_config(page_title="Cavity Detection Tool", layout="wide")
|
|
|
|
|
|
|
# Relative width of the empty padding columns on each side of the content;
# reused by the layouts further down the page.
bordersize = 0.6
_, col, _ = st.columns([bordersize, 3, bordersize])

# Title, description and input requirements, rendered in the centre column.
col.markdown("# Cavity Detection Tool")
col.markdown("Cavity Detection Tool (CADET) is a machine learning pipeline trained to detect X-ray cavities from noisy Chandra images of early-type galaxies. To use this tool: upload your image, select the scale of interest, and make a prediction! If you use this tool for your research, please cite [Plšek et al. 2023](https://arxiv.org/abs/2304.05457)")
col.markdown("Input images should be centered at the centre of the galaxy and point sources should be filled with surrounding background ([dmfilth](https://cxc.cfa.harvard.edu/ciao/ahelp/dmfilth.html)).")

# FITS upload widget; the rest of the page is only built once this is not None.
uploaded_file = col.file_uploader("Choose a FITS file", type=['fits'])
|
|
|
|
|
def plot_image(image, scale):
    """Show the input image in the ``colA`` column with the cutout outlined.

    A white square of side ``scale * 128`` pixels, centred on the image, marks
    the region that ``cut`` extracts for the chosen scale.

    Parameters
    ----------
    image : 2-D array
        Image to display (the caller passes a log-scaled image).
    scale : int
        Cutout size in units of 128 pixels.
    """
    size = scale * 128
    # Rectangle takes its lower-left corner as (x, y) = (column, row), so each
    # coordinate is centred on its own axis; this matters for non-square images.
    x0 = image.shape[1] // 2 - size / 2
    y0 = image.shape[0] // 2 - size / 2

    # Draw on an explicit figure instead of pyplot's implicit global one, hand
    # it to st.pyplot, and close it afterwards so figures do not accumulate in
    # pyplot's figure registry.
    fig, ax = plt.subplots(figsize=(4, 4))
    try:
        ax.imshow(image, origin="lower")
        ax.add_patch(Rectangle((x0, y0), size, size, linewidth=1, edgecolor='w', facecolor='none'))
        ax.axis('off')
        with colA:
            st.pyplot(fig)
    finally:
        plt.close(fig)
|
|
|
|
|
def plot_prediction(pred):
    """Show a prediction map in the ``colB`` column.

    Parameters
    ----------
    pred : 2-D array
        Prediction map to display.
    """
    # Explicit figure rather than pyplot's implicit global one; closed after
    # rendering so figures do not accumulate in pyplot's figure registry.
    fig, ax = plt.subplots(figsize=(4, 4))
    try:
        ax.imshow(pred, origin="lower")
        ax.axis('off')
        with colB:
            st.pyplot(fig)
    finally:
        plt.close(fig)
|
|
|
|
|
def cut(data0, wcs0, scale=1):
    """Cut a central ``128*scale`` square out of an image and rebin it to 128x128.

    Parameters
    ----------
    data0 : 2-D array
        Input image.
    wcs0 : astropy.wcs.WCS
        WCS of *data0*; passed to ``Cutout2D``, whose ``wcs`` attribute is the
        object adjusted and returned below.
    scale : int, optional
        Cutout size in units of 128 pixels; also the rebinning factor.

    Returns
    -------
    data : ndarray, shape (128, 128)
        The cutout averaged over ``factor x factor`` pixel blocks.
    wcs : astropy.wcs.WCS
        The cutout's WCS with ``cdelt`` multiplied by the rebinning factor and
        the reference point moved to the rebinned grid.
    """
    size = 128 * scale

    # Cutout2D expects the centre as (x, y) = (column, row); use each axis'
    # own length so non-square images are centred correctly.
    centre = (data0.shape[1] / 2, data0.shape[0] / 2)
    cutout = Cutout2D(data0, centre, (size, size), wcs=wcs0)
    data, wcs = cutout.data, cutout.wcs

    # Rebin to 128x128 by averaging factor x factor blocks.
    factor = size // 128
    data = data.reshape(128, factor, 128, factor).mean(-1).mean(1)

    # Re-anchor the WCS: sky position of cutout pixel (63, 63) (0-based) becomes
    # the reference value, and the pixel scale grows by the rebinning factor.
    # NOTE(review): for factor > 1, the rebinned pixel at crpix (1-based
    # 64 / factor) is centred on cutout pixel 64 - (factor + 1) / 2, not 63,
    # i.e. offset by (factor - 1) / 2 original pixels -- confirm this is intended.
    # NOTE(review): if the header uses a CD matrix, scaling cdelt alone may not
    # rescale the pixel size -- verify against the input headers.
    ra, dec = wcs.wcs_pix2world(np.array([[63, 63]]), 0)[0]
    wcs.wcs.cdelt[0] = wcs.wcs.cdelt[0] * factor
    wcs.wcs.cdelt[1] = wcs.wcs.cdelt[1] * factor
    wcs.wcs.crval[0] = ra
    wcs.wcs.crval[1] = dec
    wcs.wcs.crpix[0] = 64 / factor
    wcs.wcs.crpix[1] = 64 / factor

    return data, wcs
|
|
|
def decompose_cavity(pred, th2=0.7, amin=10):
    """Split a predicted cavity map into separate cavities.

    Non-zero pixels of *pred* are grouped into 8-connected components. This is
    the same grouping ``DBSCAN(eps=1.5, min_samples=3)`` produces on pixel
    coordinates: diagonal neighbours lie at distance sqrt(2) < 1.5, and groups
    of fewer than 3 pixels have no core point and are noise. A component is
    kept only if at least one of its values exceeds *th2* and the sum of its
    values exceeds *amin*.

    Parameters
    ----------
    pred : 2-D array
        Prediction map; zero pixels are background.
    th2 : float, optional
        Value that at least one pixel of a cavity must exceed.
    amin : float, optional
        Sum of pixel values that a cavity must exceed.

    Returns
    -------
    list of ndarray
        One float64 array per kept cavity, with the same shape as *pred*. Each
        array holds that cavity's values of *pred* and zeros elsewhere. The
        list is empty when nothing qualifies.
    """
    # Local import: scipy is a dependency of scikit-learn, which the intended
    # (never imported) DBSCAN call relied on.
    from scipy.ndimage import label

    pred = np.asarray(pred)
    # A 3x3 all-ones structuring element gives 8-connectivity.
    labels, n_components = label(pred != 0, structure=np.ones((3, 3), dtype=int))

    cavities = []
    for k in range(1, n_components + 1):
        mask = labels == k

        # DBSCAN with min_samples=3 treats groups of 1-2 pixels as noise.
        if np.count_nonzero(mask) < 3:
            continue

        img = np.zeros(pred.shape)
        img[mask] = pred[mask]

        # Require a confident peak somewhere in the cavity.
        if not (img > th2).any():
            continue

        # Drop cavities whose summed prediction is too small.
        if np.sum(img) <= amin:
            continue

        cavities.append(img)

    return cavities
|
|
|
|
|
if uploaded_file is not None:
    # Read the image and its WCS from the primary HDU of the uploaded file.
    with fits.open(uploaded_file) as hdul:
        data = hdul[0].data
        wcs = WCS(hdul[0].header)

    # Everything below needs a 2-D image at least 128 px on each side. A smaller
    # image would leave the scale selector empty and crash on scale.split().
    if data is None or data.ndim != 2 or min(data.shape) < 128:
        col.error("The primary HDU must contain a 2-D image of at least 128x128 pixels.")
        st.stop()

    _, col1, col2, col3, col4, col5, col6, _ = st.columns([bordersize, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, bordersize])
    col1.subheader("Input image")
    col3.subheader("Prediction")
    col5.subheader("Decomposed")

    with col2:
        # Largest square cutout (in units of 128 px) that fits inside the image;
        # use the shorter axis so the cutout also fits non-square images.
        max_scale = min(data.shape) // 128
        scale = st.selectbox('Scale:', [f"{(i+1)*128}x{(i+1)*128}" for i in range(max_scale)], label_visibility="hidden")
        # "256x256" -> 2, i.e. the cutout size in units of 128 px.
        scale = int(scale.split("x")[0]) // 128

    with col4:
        detect = st.button('Detect')

    with col6:
        decompose = st.button('Decompose')

    _, colA, colB, colC, _ = st.columns([bordersize, 1, 1, 1, bordersize])

    # Log-scale the counts for display; +1 keeps zero-count pixels finite.
    image = np.log10(data + 1)
    plot_image(image, scale)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|