# CADET / app.py — Cavity Detection Tool Streamlit app (HuggingFace Space by Plsek).
# Source snapshot: commit e52c641 ("Update app.py"), 6.78 kB.
# HuggingFace
from huggingface_hub import from_pretrained_keras
import streamlit as st


@st.cache_resource(show_spinner=False)
def _load_model():
    """Load the pretrained CADET network ("Plsek/CADET-v1") from the HuggingFace Hub.

    Streamlit re-executes this whole script on every widget interaction, so a
    plain module-level load would re-deserialize the network on each rerun.
    Caching it as a shared resource loads it once per server process.
    ``show_spinner=False`` keeps this call from emitting any UI element before
    ``st.set_page_config`` runs further down.
    """
    return from_pretrained_keras("Plsek/CADET-v1")


model = _load_model()
# Basic libraries
import os
import shutil
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
# Astropy
from astropy.io import fits
from astropy.wcs import WCS
from astropy.nddata import Cutout2D, CCDData
from astropy.convolution import Gaussian2DKernel as Gauss
from astropy.convolution import convolve
# Scikit-learn
from sklearn.cluster import DBSCAN
# Streamlit
import streamlit as st
st.set_option('deprecation.showPyplotGlobalUse', False)
st.set_page_config(page_title="Cavity Detection Tool", layout="wide")

# Relative width of the empty margin columns used to centre the page content.
bordersize = 0.6
_, col, _ = st.columns([bordersize, 3, bordersize])

# Directory whose contents are zipped for download. ``exist_ok`` makes this
# idempotent across Streamlit reruns. The previous ``os.system("mkdir ...")``
# spawned a shell on every rerun and failed noisily once the directory existed.
os.makedirs("predictions", exist_ok=True)

with col:
    st.markdown("# Cavity Detection Tool")
    st.markdown("Cavity Detection Tool (CADET) is a machine learning pipeline trained to detect X-ray cavities from noisy Chandra images of early-type galaxies.")
    st.markdown("To use this tool: upload your image, select the scale of interest, and make a prediction!")
    st.markdown("If you use this tool for your research, please cite [Plšek et al. 2023](https://arxiv.org/abs/2304.05457)")
    st.markdown("Input images should be centered at the centre of the galaxy and point sources should be filled with surrounding background ([dmfilth](https://cxc.cfa.harvard.edu/ciao/ahelp/dmfilth.html)).")

    # Create file uploader widget
    uploaded_file = st.file_uploader("Choose a FITS file", type=['fits'])
# Define function to plot the uploaded image
def plot_image(image, scale):
    """Show the input image with the region that will be cut out outlined.

    Parameters
    ----------
    image : 2-D array
        Image to display (the app passes ``log10(data + 1)``).
    scale : int
        Cutout side length in units of 128 pixels. A white square of side
        ``scale * 128`` is drawn around ``image.shape[0] // 2`` on both axes,
        which assumes a square image.

    Renders into the module-level ``colA`` column, which must already exist
    when this is called.
    """
    size = scale * 128
    x0 = image.shape[0] // 2 - size / 2
    fig, ax = plt.subplots(figsize=(4, 4))
    ax.imshow(image, origin="lower")
    ax.add_patch(Rectangle((x0, x0), size, size, linewidth=1, edgecolor='w', facecolor='none'))
    ax.axis('off')
    with colA:
        st.pyplot(fig)
    # Release the figure. Streamlit reruns the script on every interaction, and
    # unclosed figures would pile up in pyplot's global figure registry.
    plt.close(fig)
# Define function to plot the prediction
def plot_prediction(pred):
    """Show the prediction map ``pred`` in the module-level ``colB`` column.

    ``colB`` must already exist when this is called.
    """
    fig, ax = plt.subplots(figsize=(4, 4))
    ax.imshow(pred, origin="lower")
    ax.axis('off')
    with colB:
        st.pyplot(fig)
    # Close the figure so repeated reruns do not accumulate open pyplot figures.
    plt.close(fig)
# Define function to plot the decomposed prediction
def plot_decomposed(pred):
    """Show the decomposed (cavity-labelled) map ``pred`` in the module-level ``colC`` column.

    ``colC`` must already exist when this is called.
    """
    fig, ax = plt.subplots(figsize=(4, 4))
    ax.imshow(pred, origin="lower")
    ax.axis('off')
    with colC:
        st.pyplot(fig)
    # Close the figure so repeated reruns do not accumulate open pyplot figures.
    plt.close(fig)
# Cut input image and rebin it to 128x128 pixels
def cut(data0, wcs0, scale=1):
    """Cut a centred ``128*scale`` square out of an image and rebin it to 128x128.

    Parameters
    ----------
    data0 : 2-D array
        Full input image. The cutout is centred at ``data0.shape[0] / 2`` on
        both axes, i.e. a square image is assumed.
    wcs0 : astropy.wcs.WCS
        WCS of ``data0``.
    scale : int, optional
        Cutout side length in units of 128 pixels; also the rebinning factor.

    Returns
    -------
    data : ndarray, shape (128, 128)
        The cutout averaged over ``scale`` x ``scale`` pixel blocks.
    wcs : astropy.wcs.WCS
        WCS describing the rebinned ``data``.

    Raises
    ------
    ValueError
        If the requested cutout does not fit inside ``data0``.
    """
    shape = data0.shape[0]
    x0 = shape / 2
    size = 128 * scale
    cutout = Cutout2D(data0, (x0, x0), (size, size), wcs=wcs0)
    data, wcs = cutout.data, cutout.wcs
    if data.shape != (size, size):
        # Cutout2D (default mode="trim") silently shrinks cutouts that run past
        # the image edge; fail with a clear message instead of in reshape().
        raise ValueError(
            f"cannot cut a {size}x{size} region from an image of shape {data0.shape}"
        )
    # Regrid data: average each factor x factor block into one output pixel.
    factor = size // 128
    data = data.reshape(128, factor, 128, factor).mean(-1).mean(1)
    # Regrid wcs. With strided slicing, astropy scales CDELT (or the CD matrix)
    # and shifts CRPIX so that each new pixel centre maps to the centre of its
    # factor x factor block. The former hand-written update (CRPIX = 64/factor
    # paired with the world position of unbinned pixel 63) was off by
    # (factor - 1) / 2 original pixels and ignored CD-matrix headers.
    wcs = wcs[::factor, ::factor]
    return data, wcs
def decompose_cavity(pred, th2=0.7, amin=10):
    """Split a (thresholded) prediction map into individual cavity maps.

    Non-zero pixels are clustered with DBSCAN (``eps=1.5`` pixels, which
    reaches the 8 adjacent pixels, and ``min_samples=3``). Points DBSCAN labels
    as noise are discarded.

    Parameters
    ----------
    pred : 2-D array
        Prediction map; zero-valued pixels are ignored.
    th2 : float, optional
        A cluster is kept only if at least one of its pixels exceeds ``th2``.
    amin : float, optional
        A cluster is kept only if the sum of its pixel values exceeds ``amin``.
        NOTE(review): the inline comment calls this a minimal *area*, but the
        test sums predicted values rather than counting pixels.

    Returns
    -------
    list of ndarray
        One array per kept cluster, shaped like ``pred``. Each holds ``pred``'s
        values on that cluster's pixels and zero elsewhere, ordered by
        cluster label.
    """
    X, Y = pred.nonzero()
    # DBSCAN cannot be fitted on zero samples, so there is nothing to decompose.
    if X.size == 0:
        return []
    data = np.column_stack([X, Y])
    # DBSCAN CLUSTERING ALGORITHM
    clusters = DBSCAN(eps=1.5, min_samples=3).fit(data).labels_
    cavities = []
    # Labels are 0..K-1 for clusters and -1 for noise; skip the noise label.
    for label in sorted(set(clusters) - {-1}):
        img = np.zeros(pred.shape)
        b = clusters == label
        xi, yi = X[b], Y[b]
        img[xi, yi] = pred[xi, yi]
        # THRESHOLDING #2
        if not (img > th2).any():
            continue
        # MINIMAL AREA
        if np.sum(img) <= amin:
            continue
        cavities.append(img)
    return cavities
# If file is uploaded, read in the data and plot it
if uploaded_file is not None:
    with fits.open(uploaded_file) as hdul:
        data = hdul[0].data
        wcs = WCS(hdul[0].header)

    # Controls row: two margins plus six content columns holding the
    # subheaders, scale selector, buttons, threshold slider and download button.
    _, col1, col2, col3, col4, col5, col6, _ = st.columns([bordersize, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, bordersize])
    col1.subheader("Input image")
    col3.subheader("Prediction")
    col5.subheader("Decomposed")

    with col1:
        # Offer every square cutout size (multiples of 128 px) that fits the image.
        max_scale = int(data.shape[0] // 128)
        scale = st.selectbox('Scale:', [f"{(i+1)*128}x{(i+1)*128}" for i in range(max_scale)], label_visibility="hidden")
        scale = int(scale.split("x")[0]) // 128

    with col3:
        detect = st.button('Detect')

    with col5:
        decompose = st.button('Decompose')

    # Plot row: input image, prediction and decomposition side by side.
    _, colA, colB, colC, _ = st.columns([bordersize, 1, 1, 1, bordersize])

    image = np.log10(data + 1)
    plot_image(image, scale)

    with col4:
        st.markdown("""<style>[data-baseweb="select"] {margin-top: -36px;}</style>""", unsafe_allow_html=True)
        threshold = st.slider("Threshold", 0.0, 1.0, 0.0, 0.05, label_visibility="hidden")

    # NOTE(review): pred.npy / thresh.npy / decomposed.npy and predictions/ live
    # in the working directory, so they are shared by every session served by
    # this process and survive new uploads; consider st.session_state instead.

    # The rebinned cutout and its WCS are needed by both actions. Re-cutting on
    # Decompose ensures the FITS output carries the cutout's WCS, not the
    # full-image header WCS, even on reruns where Detect was not clicked.
    if detect or decompose:
        cut_data, cut_wcs = cut(data, wcs, scale=scale)

    if detect:
        cut_image = np.log10(cut_data + 1)
        # Test-time augmentation: predict on all four 90-degree rotations in a
        # single batch, rotate each prediction back and average them.
        batch = np.stack([np.rot90(cut_image, j) for j in range(4)])
        preds = model.predict(batch.reshape(4, 128, 128, 1)).reshape(4, 128, 128)
        y_pred = np.mean([np.rot90(p, -j) for j, p in enumerate(preds)], axis=0)
        np.save("pred.npy", y_pred)

    # Reload the last prediction (if any) so it persists across reruns.
    try:
        y_pred = np.load("pred.npy")
    except (OSError, ValueError):
        y_pred = np.zeros((128, 128))
    y_pred = np.where(y_pred > threshold, y_pred, 0)
    np.save("thresh.npy", y_pred)
    plot_prediction(y_pred)

    if decompose:
        cavs = decompose_cavity(y_pred)
        # Start from an empty output directory so cavity files from an earlier
        # decomposition do not leak into the new archive.
        shutil.rmtree("predictions", ignore_errors=True)
        os.makedirs("predictions", exist_ok=True)
        ccd = CCDData(y_pred, unit="adu", wcs=cut_wcs)
        ccd.write("predictions/predicted.fits", overwrite=True)
        image_decomposed = np.zeros((128, 128))
        for i, cav in enumerate(cavs):
            ccd = CCDData(cav, unit="adu", wcs=cut_wcs)
            ccd.write(f"predictions/predicted_{i+1}.fits", overwrite=True)
            # Label map: cavity i is painted with value i + 1.
            image_decomposed += (i + 1) * np.where(cav > 0, 1, 0)
        # Creates predictions.zip from the contents of predictions/.
        shutil.make_archive("predictions", 'zip', "predictions")
        np.save("decomposed.npy", image_decomposed)

    try:
        image_decomposed = np.load("decomposed.npy")
    except (OSError, ValueError):
        image_decomposed = np.zeros((128, 128))
    plot_decomposed(image_decomposed)

    with col6:
        # Offer the archive only once a decomposition has produced it.
        if os.path.exists("predictions.zip"):
            with open("predictions.zip", "rb") as f:
                st.download_button(label="Download", data=f.read(), file_name="predicted.zip", mime="application/octet-stream")