|  |  | 
					
						
						|  | from huggingface_hub import from_pretrained_keras | 
					
						
						|  | model = from_pretrained_keras("Plsek/CADET-v1") | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
import io
import os
import shutil
import zipfile

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle

from astropy.io import fits
from astropy.wcs import WCS
from astropy.nddata import Cutout2D, CCDData
from astropy.convolution import Gaussian2DKernel as Gauss
from astropy.convolution import convolve

from sklearn.cluster import DBSCAN

import streamlit as st
					
						
						|  | st.set_option('deprecation.showPyplotGlobalUse', False) | 
					
						
						|  | st.set_page_config(page_title="Cavity Detection Tool", layout="wide") | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | bordersize = 0.6 | 
					
						
						|  | _, col, _ = st.columns([bordersize, 3, bordersize]) | 
					
						
						|  |  | 
					
						
						|  | os.system("mkdir predictions") | 
					
						
						|  |  | 
					
						
						|  | with col: | 
					
						
						|  | st.markdown("# Cavity Detection Tool") | 
					
						
						|  |  | 
					
						
						|  | st.markdown("Cavity Detection Tool (CADET) is a machine learning pipeline trained to detect X-ray cavities from noisy Chandra images of early-type galaxies.") | 
					
						
						|  | st.markdown("To use this tool: upload your image, select the scale of interest, and make a prediction!") | 
					
						
						|  | st.markdown("If you use this tool for your research, please cite [Plšek et al. 2023](https://arxiv.org/abs/2304.05457)") | 
					
						
						|  | st.markdown("Input images should be centered at the centre of the galaxy and point sources should be filled with surrounding background ([dmfilth](https://cxc.cfa.harvard.edu/ciao/ahelp/dmfilth.html)).") | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | uploaded_file = st.file_uploader("Choose a FITS file", type=['fits']) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def plot_image(image, scale): | 
					
						
						|  | plt.figure(figsize=(4, 4)) | 
					
						
						|  |  | 
					
						
						|  | x0 = image.shape[0] // 2 - scale * 128 / 2 | 
					
						
						|  | plt.imshow(image, origin="lower") | 
					
						
						|  | plt.gca().add_patch(Rectangle((x0, x0), scale*128, scale*128, linewidth=1, edgecolor='w', facecolor='none')) | 
					
						
						|  |  | 
					
						
						|  | plt.axis('off') | 
					
						
						|  | with colA: st.pyplot() | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def plot_prediction(pred): | 
					
						
						|  | plt.figure(figsize=(4, 4)) | 
					
						
						|  | plt.imshow(pred, origin="lower") | 
					
						
						|  | plt.axis('off') | 
					
						
						|  | with colB: st.pyplot() | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def plot_decomposed(pred): | 
					
						
						|  | plt.figure(figsize=(4, 4)) | 
					
						
						|  | plt.imshow(pred, origin="lower") | 
					
						
						|  | plt.axis('off') | 
					
						
						|  | with colC: st.pyplot() | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def cut(data0, wcs0, scale=1): | 
					
						
						|  | shape = data0.shape[0] | 
					
						
						|  | x0 = shape / 2 | 
					
						
						|  | size = 128 * scale | 
					
						
						|  | cutout = Cutout2D(data0, (x0, x0), (size, size), wcs=wcs0) | 
					
						
						|  | data, wcs = cutout.data, cutout.wcs | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | factor = size // 128 | 
					
						
						|  | data = data.reshape(128, factor, 128, factor).mean(-1).mean(1) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | ra, dec = wcs.wcs_pix2world(np.array([[63, 63]]),0)[0] | 
					
						
						|  | wcs.wcs.cdelt[0] = wcs.wcs.cdelt[0] * factor | 
					
						
						|  | wcs.wcs.cdelt[1] = wcs.wcs.cdelt[1] * factor | 
					
						
						|  | wcs.wcs.crval[0] = ra | 
					
						
						|  | wcs.wcs.crval[1] = dec | 
					
						
						|  | wcs.wcs.crpix[0] = 64 / factor | 
					
						
						|  | wcs.wcs.crpix[1] = 64 / factor | 
					
						
						|  |  | 
					
						
						|  | return data, wcs | 
					
						
						|  |  | 
					
						
						|  | def decompose_cavity(pred, th2=0.7, amin=10): | 
					
						
						|  | X, Y = pred.nonzero() | 
					
						
						|  | data = np.array([X,Y]).reshape(2, -1) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | try: clusters = DBSCAN(eps=1.5, min_samples=3).fit(data.T).labels_ | 
					
						
						|  | except: clusters = [] | 
					
						
						|  |  | 
					
						
						|  | N = len(set(clusters)) | 
					
						
						|  | cavities = [] | 
					
						
						|  |  | 
					
						
						|  | for i in range(N): | 
					
						
						|  | img = np.zeros((128,128)) | 
					
						
						|  | b = clusters == i | 
					
						
						|  | xi, yi = X[b], Y[b] | 
					
						
						|  | img[xi, yi] = pred[xi, yi] | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if not (img > th2).any(): continue | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if np.sum(img) <= amin: continue | 
					
						
						|  |  | 
					
						
						|  | cavities.append(img) | 
					
						
						|  |  | 
					
						
						|  | return cavities | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if uploaded_file is not None: | 
					
						
						|  | with fits.open(uploaded_file) as hdul: | 
					
						
						|  | data = hdul[0].data | 
					
						
						|  | wcs = WCS(hdul[0].header) | 
					
						
						|  | y_pred = np.zeros((128,128)) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | _, col1, col2, col3, col4, col5, col6, _ = st.columns([bordersize,0.5,0.5,0.5,0.5,0.5,0.5,bordersize]) | 
					
						
						|  | col1.subheader("Input image") | 
					
						
						|  | col3.subheader("Prediction") | 
					
						
						|  | col5.subheader("Decomposed") | 
					
						
						|  |  | 
					
						
						|  | with col1: | 
					
						
						|  |  | 
					
						
						|  | max_scale = int(data.shape[0] // 128) | 
					
						
						|  | scale = st.selectbox('Scale:',[f"{(i+1)*128}x{(i+1)*128}" for i in range(max_scale)], label_visibility="hidden") | 
					
						
						|  | scale = int(scale.split("x")[0]) // 128 | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | with col3: | 
					
						
						|  | detect = st.button('Detect') | 
					
						
						|  |  | 
					
						
						|  | with col5: | 
					
						
						|  | decompose = st.button('Docompose') | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | _, colA, colB, colC, _ = st.columns([bordersize,1,1,1,bordersize]) | 
					
						
						|  |  | 
					
						
						|  | image = np.log10(data+1) | 
					
						
						|  | plot_image(image, scale) | 
					
						
						|  |  | 
					
						
						|  | with col4: | 
					
						
						|  | st.markdown("""<style>[data-baseweb="select"] {margin-top: -36px;}</style>""", unsafe_allow_html=True) | 
					
						
						|  | threshold = st.slider("", 0.0, 1.0, 0.0, 0.05, label_visibility="hidden") | 
					
						
						|  |  | 
					
						
						|  | if detect: | 
					
						
						|  | data, wcs = cut(data, wcs, scale=scale) | 
					
						
						|  | image = np.log10(data+1) | 
					
						
						|  |  | 
					
						
						|  | y_pred = 0 | 
					
						
						|  | for j in [0,1,2,3]: | 
					
						
						|  | rotated = np.rot90(image, j) | 
					
						
						|  | pred = model.predict(rotated.reshape(1, 128, 128, 1)).reshape(128 ,128) | 
					
						
						|  | pred = np.rot90(pred, -j) | 
					
						
						|  | y_pred += pred / 4 | 
					
						
						|  |  | 
					
						
						|  | np.save("pred.npy", y_pred) | 
					
						
						|  |  | 
					
						
						|  | try: y_pred = np.load("pred.npy") | 
					
						
						|  | except: y_pred = np.zeros((128,128)) | 
					
						
						|  | y_pred = np.where(y_pred > threshold, y_pred, 0) | 
					
						
						|  | np.save("thresh.npy", y_pred) | 
					
						
						|  |  | 
					
						
						|  | plot_prediction(y_pred) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if decompose: | 
					
						
						|  | y_pred = np.load("thresh.npy") | 
					
						
						|  |  | 
					
						
						|  | cavs = decompose_cavity(y_pred) | 
					
						
						|  |  | 
					
						
						|  | ccd = CCDData(y_pred, unit="adu", wcs=wcs) | 
					
						
						|  | ccd.write(f"predicted.fits", overwrite=True) | 
					
						
						|  | image_decomposed = np.zeros((128,128)) | 
					
						
						|  | for i, cav in enumerate(cavs): | 
					
						
						|  | ccd = CCDData(cav, unit="adu", wcs=wcs) | 
					
						
						|  | ccd.write(f"predicted_{i+1}.fits", overwrite=True) | 
					
						
						|  | image_decomposed += (i+1) * np.where(cav > 0, 1, 0) | 
					
						
						|  |  | 
					
						
						|  | shutil.make_archive("predictions", 'zip', "predictions") | 
					
						
						|  | np.save("decomposed.npy", image_decomposed) | 
					
						
						|  |  | 
					
						
						|  | try: image_decomposed = np.load("decomposed.npy") | 
					
						
						|  | except: image_decomposed = np.zeros((128,128)) | 
					
						
						|  | plot_decomposed(image_decomposed) | 
					
						
						|  |  | 
					
						
						|  | with col6: | 
					
						
						|  | pass | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | download = st.download_button(label="Download", data=res, file_name="predicted.zip", mime="application/octet-stream") |