File size: 8,129 Bytes
0012f0e 57f9dda 0ea9e86 7454012 e52c641 2bb6d79 6414f94 b4f54b4 6414f94 8111624 2680aa6 7454012 6414f94 3190548 d01a85f 7454012 1197a1f 7454012 6414f94 d01a85f ddce24a d01a85f 2680aa6 3190548 32d9557 3190548 8120054 c1f2126 6414f94 32d9557 6414f94 29cd1e9 47ec853 29cd1e9 d0a6846 8111624 7cb9bcf 8111624 48b10cd 8111624 29cd1e9 0755e66 6414f94 d26c581 7454012 d26c581 7454012 d26c581 44cb863 6414f94 0755e66 5d6cf3b 9219e24 a8b8828 4161769 a8b8828 0755e66 38ebf86 15c5050 0755e66 7e4e8dd 15c5050 0755e66 7c32f83 15c5050 0755e66 7c32f83 15c5050 0755e66 393d78f 17b9e5d 6b286d7 393d78f e9ef2e7 0755e66 e9ef2e7 cb61555 e9ef2e7 0755e66 e9ef2e7 992ff70 17c52a0 992ff70 0012f0e e9ef2e7 992ff70 0012f0e b046ded b635f79 6414f94 9c67ee7 8516575 36e4579 125c6bf 58a68f2 c6b81d0 0e52fa4 36a5e8a b046ded 125c6bf 17c52a0 bb17b88 8e1fb51 0755e66 17b9e5d 0755e66 c6b81d0 0755e66 17c52a0 0755e66 e842a8b 15c5050 bb17b88 a18eaf5 58a68f2 bb17b88 32d9557 a8b8828 e3ab055 4161769 7bcd4a9 9219e24 c6b81d0 d61ba29 393d78f 9219e24 d61ba29 091111e d61ba29 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 |
# HuggingFace Hub
from huggingface_hub import from_pretrained_keras
import streamlit as st


@st.cache(allow_output_mutation=True, show_spinner=False)
def _load_model(repo_id="Plsek/CADET-v1"):
    """Load the CADET Keras model from the HuggingFace Hub repository ``repo_id``.

    Streamlit re-executes this script on every widget interaction; caching the
    loader means the model is fetched and deserialised once instead of on
    every rerun. ``allow_output_mutation=True`` keeps Streamlit from hashing
    the returned model object, and the spinner is disabled so that no element
    is emitted before ``st.set_page_config`` is called further down.
    """
    return from_pretrained_keras(repo_id)


model = _load_model()
# Basic libraries
import os
import shutil
import numpy as np
from scipy.ndimage import center_of_mass
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
from matplotlib.patches import Rectangle
# Astropy
from astropy.io import fits
from astropy.wcs import WCS
from astropy.nddata import Cutout2D, CCDData
from astropy.convolution import Gaussian2DKernel as Gauss
from astropy.convolution import convolve
# Scikit-learn
from sklearn.cluster import DBSCAN
# Streamlit
import streamlit as st
# Hide Streamlit's deprecation warning about calling ``st.pyplot()`` without
# passing a figure (i.e. rendering the global pyplot figure).
# NOTE(review): this config key exists only in older Streamlit releases —
# confirm against the pinned Streamlit version before upgrading.
st.set_option('deprecation.showPyplotGlobalUse', False)
# Define function to plot the uploaded image
def plot_image(image, scale):
    """Show ``image`` in column ``colA`` with the region fed to the model outlined.

    Parameters
    ----------
    image : 2-D array
        Image to display (the app passes ``log10(data + 1)``).
    scale : int
        Cutout size in units of 128 pixels; a white square of side
        ``scale * 128`` centred on the image marks the cutout.
    """
    fig = plt.figure(figsize=(4, 4))
    size = scale * 128
    # Lower-left corner of a size x size square centred on the image
    # (x runs along columns = shape[1], y along rows = shape[0]).
    x0 = image.shape[1] // 2 - size / 2
    y0 = image.shape[0] // 2 - size / 2
    plt.imshow(image, origin="lower")
    plt.gca().add_patch(Rectangle((x0, y0), size, size, linewidth=1, edgecolor='w', facecolor='none'))
    plt.axis('off')
    with colA:
        st.pyplot(fig)
    # Pass the figure explicitly and release it afterwards: the script reruns
    # on every interaction and pyplot keeps figures alive until closed.
    plt.close(fig)
# Define function to plot the prediction
def plot_prediction(pred):
    """Show the (thresholded) prediction map ``pred`` in column ``colB``.

    Parameters
    ----------
    pred : 2-D array
        Prediction image to display.
    """
    fig = plt.figure(figsize=(4, 4))
    plt.imshow(pred, origin="lower")
    plt.axis('off')
    with colB:
        st.pyplot(fig)
    # Release the figure so reruns do not accumulate open pyplot figures.
    plt.close(fig)
# Define function to plot the decomposed prediction
def plot_decomposed(decomposed):
    """Show the decomposed prediction in column ``colC`` with every cavity numbered.

    Parameters
    ----------
    decomposed : 2-D array
        Label image: 0 for background and ``k`` for the pixels of the k-th
        cavity (the app passes the output of ``decompose_cavity``).
    """
    fig = plt.figure(figsize=(4, 4))
    plt.imshow(decomposed, origin="lower")
    n_labels = int(np.max(decomposed))
    for label in range(1, n_labels + 1):
        mask = np.where(decomposed == label, 1, 0)
        if not mask.any():
            # center_of_mass of an empty mask is NaN; nothing to annotate.
            continue
        cy, cx = center_of_mass(mask)
        # Light text for the lower half of the labels, dark for the upper half
        # (presumably for contrast with the default colormap — TODO confirm).
        color = "white" if label - 1 < n_labels // 2 else "black"
        # imshow draws columns along x and rows along y.
        plt.text(cx, cy, f"{label}", ha="center", va="center", fontsize=15, color=color)
    plt.axis('off')
    with colC:
        st.pyplot(fig)
    # Release the figure so reruns do not accumulate open pyplot figures.
    plt.close(fig)
# Define function to cut input image and rebin it to 128x128 pixels
def cut(data0, wcs0, scale=1):
    """Cut a centred square out of ``data0`` and block-average it to 128x128 pixels.

    Parameters
    ----------
    data0 : 2-D array
        Input image.
    wcs0 : astropy.wcs.WCS
        WCS of ``data0``.
    scale : int, optional
        Cutout side in units of 128 pixels; the ``128 * scale`` pixel cutout
        is binned down by a factor ``scale``.

    Returns
    -------
    data : ndarray, shape (128, 128)
        Block-averaged cutout.
    wcs : astropy.wcs.WCS
        WCS describing the rebinned cutout.
    """
    size = 128 * scale
    # Cutout2D takes the position as (x, y) = (column, row).
    centre = (data0.shape[1] / 2, data0.shape[0] / 2)
    cutout = Cutout2D(data0, centre, (size, size), wcs=wcs0)
    data, wcs = cutout.data, cutout.wcs

    # Regrid data: average every factor x factor block into one pixel.
    factor = size // 128
    data = data.reshape(128, factor, 128, factor).mean(-1).mean(1)

    # Regrid WCS without moving the projection's reference point (CRVAL).
    # Binned pixel J (1-based) covers original pixels centred on
    # (J - 0.5) * factor + 0.5, so an original reference pixel c maps to
    # J = (c - 0.5) / factor + 0.5 (identity when factor == 1).
    # NOTE(review): assumes the linear transform uses CDELT (optionally with
    # PC); a header using a CD matrix would need CD scaled instead.
    for axis in (0, 1):
        wcs.wcs.cdelt[axis] = wcs.wcs.cdelt[axis] * factor
        wcs.wcs.crpix[axis] = (wcs.wcs.crpix[axis] - 0.5) / factor + 0.5
    return data, wcs
# Define function to apply cutting and produce a prediction
@st.cache
def cut_n_predict(data, wcs, scale):
    """Cut/rebin the input image and predict cavities, averaging over rotations.

    Parameters
    ----------
    data : 2-D array
        Full input image.
    wcs : astropy.wcs.WCS
        WCS of ``data``.
    scale : int
        Cutout size in units of 128 pixels (passed to ``cut``).

    Returns
    -------
    y_pred : ndarray, shape (128, 128)
        Model output averaged over the four 90-degree rotations of the input.
    wcs : astropy.wcs.WCS
        WCS of the rebinned 128x128 cutout.
    """
    data, wcs = cut(data, wcs, scale=scale)
    image = np.log10(data + 1)
    # Test-time augmentation: send all four 90-degree rotations through the
    # network as a single batch (one predict() call instead of four), rotate
    # each output back to the original orientation and average them.
    batch = np.stack([np.rot90(image, k) for k in range(4)]).reshape(4, 128, 128, 1)
    preds = model.predict(batch).reshape(4, 128, 128)
    y_pred = np.mean([np.rot90(pred, -k) for k, pred in enumerate(preds)], axis=0)
    return y_pred, wcs
# Define function to decompose prediction into individual cavities
@st.cache
def _find_cavities(pred, th2, amin):
    """Cluster the nonzero pixels of ``pred`` with DBSCAN and keep significant clusters.

    Returns a list of float arrays shaped like ``pred``; each holds the
    prediction values of one accepted cluster and zeros elsewhere, in
    cluster-label order. Pure computation, so it is safe to cache.
    """
    X, Y = pred.nonzero()
    if X.size == 0:
        # DBSCAN cannot be fitted on an empty sample set.
        return []
    labels = DBSCAN(eps=1.0, min_samples=3).fit(np.column_stack([X, Y])).labels_
    cavities = []
    # DBSCAN marks noise points with label -1; they never form a cavity.
    for label in sorted(set(labels) - {-1}):
        member = labels == label
        xi, yi = X[member], Y[member]
        img = np.zeros(pred.shape)
        img[xi, yi] = pred[xi, yi]
        # Thresholding #2: at least one pixel must exceed th2.
        if not (img > th2).any():
            continue
        # Minimal "area" — NOTE(review): this sums prediction values rather
        # than counting pixels; confirm that is intended.
        if np.sum(img) <= amin:
            continue
        cavities.append(img)
    return cavities


def decompose_cavity(pred, th2=0.7, amin=10):
    """Split a prediction into individual cavities and save them as FITS files.

    Writes ``predictions/predicted.fits`` (the full ``pred``) and
    ``predictions/predicted_<k>.fits`` for every accepted cavity, after
    removing per-cavity files left over from earlier runs. The FITS headers
    use the module-level ``wcs`` assigned by the main script.

    Parameters
    ----------
    pred : 2-D array
        (Thresholded) prediction map.
    th2 : float, optional
        A cavity is kept only if one of its pixels exceeds this value.
    amin : float, optional
        A cavity is kept only if the sum of its values exceeds this value.

    Returns
    -------
    ndarray
        Label image shaped like ``pred``: 0 for background, ``k`` for the
        pixels of the k-th cavity.
    """
    # Clustering is cached; the file I/O below is not, so the files on disk
    # always match the current prediction even on a cache hit.
    cavities = _find_cavities(pred, th2, amin)

    os.makedirs("predictions", exist_ok=True)
    # Drop per-cavity files from a previous decomposition so the download
    # archive contains only the current cavities.
    for name in os.listdir("predictions"):
        if name.startswith("predicted_") and name.endswith(".fits"):
            os.remove(os.path.join("predictions", name))

    # Save raw and decomposed predictions to predictions folder
    ccd = CCDData(pred, unit="adu", wcs=wcs)
    ccd.write("predictions/predicted.fits", overwrite=True)
    image_decomposed = np.zeros(pred.shape)
    for k, cav in enumerate(cavities, start=1):
        ccd = CCDData(cav, unit="adu", wcs=wcs)
        ccd.write(f"predictions/predicted_{k}.fits", overwrite=True)
        image_decomposed += k * np.where(cav > 0, 1, 0)
    return image_decomposed
# Use wide layout and create columns
st.set_page_config(page_title="Cavity Detection Tool", layout="wide")
bordersize = 0.6
_, col, _ = st.columns([bordersize, 3, bordersize])

# Folder collecting the FITS files offered for download.
os.makedirs("predictions", exist_ok=True)

with col:
    # Create heading and description
    st.markdown("# Cavity Detection Tool")
    st.markdown("Cavity Detection Tool (CADET) is a machine learning pipeline trained to detect X-ray cavities from noisy Chandra images of early-type galaxies.")
    st.markdown("To use this tool: upload your image, select the scale of interest, and make a prediction!")
    st.markdown("Input images should be centred at the centre of the galaxy and point sources should be filled with surrounding background ([dmfilth](https://cxc.cfa.harvard.edu/ciao/ahelp/dmfilth.html)).")
    st.markdown("If you use this tool for your research, please cite [Plšek et al. 2023](https://arxiv.org/abs/2304.05457)")

_, col_1, _, col_2, _ = st.columns([bordersize, 2.0, 0.5, 0.5, bordersize])

with col_1:
    # Create file uploader widget
    uploaded_file = st.file_uploader("Choose a FITS file", type=['fits'])

with col_2:
    st.markdown("# Examples")
    NGC4649 = st.button("NGC4649")
    NGC5813 = st.button("NGC5813")

# st.button is True only on the rerun triggered by the click, so remember the
# chosen example; otherwise the next interaction (Detect, slider, ...) would
# drop it again.
if NGC4649:
    st.session_state["example"] = "NGC4649_example.fits"
elif NGC5813:
    st.session_state["example"] = "NGC5813_example.fits"
# A freshly clicked example wins; otherwise an uploaded file takes precedence
# over a previously chosen example.
if NGC4649 or NGC5813 or uploaded_file is None:
    uploaded_file = st.session_state.get("example", uploaded_file)

# If file is uploaded, read in the data and plot it
if uploaded_file is not None:
    with fits.open(uploaded_file) as hdul:
        data = hdul[0].data
        wcs = WCS(hdul[0].header)

    # cut() needs at least one full 128x128 cutout of a 2-D image.
    if data is None or data.ndim != 2 or min(data.shape) < 128:
        st.error("The primary HDU must contain a 2-D image of at least 128x128 pixels.")
        st.stop()
    # Largest selectable cutout: whole multiples of 128 px fitting both axes.
    max_scale = min(data.shape) // 128

    # Make six columns for buttons
    _, col1, col2, col3, col4, col5, col6, _ = st.columns([bordersize, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, bordersize])
    col1.subheader("Input image")
    col3.subheader("Prediction")
    col5.subheader("Decomposed")
    col6.subheader("")

    with col1:
        st.markdown("""<style>[data-baseweb="select"] {margin-top: -46px;}</style>""", unsafe_allow_html=True)
        scale = st.selectbox('Scale:', [f"{(i+1)*128}x{(i+1)*128}" for i in range(max_scale)], label_visibility="hidden")
        scale = int(scale.split("x")[0]) // 128

    # Detect button
    with col3:
        detect = st.button('Detect', key="detect")

    # Threshold slider
    with col4:
        st.markdown("")
        threshold = st.slider("Threshold", 0.0, 1.0, 0.0, 0.05)

    # Decompose button
    with col5:
        decompose = st.button('Decompose', key="decompose")

    # Make three columns for plots (input, prediction, decomposition)
    _, colA, colB, colC, _ = st.columns([bordersize, 1, 1, 1, bordersize])

    image = np.log10(data + 1)
    plot_image(image, scale)

    # A nonzero threshold keeps the prediction visible across reruns.
    if detect or threshold:
        y_pred, wcs = cut_n_predict(data, wcs, scale)
        y_pred_th = np.where(y_pred > threshold, y_pred, 0)
        plot_prediction(y_pred_th)

        # Also re-show the decomposition on the rerun triggered by Download.
        if decompose or st.session_state.get("download", False):
            image_decomposed = decompose_cavity(y_pred_th)
            plot_decomposed(image_decomposed)

            with col6:
                st.markdown("<br style='margin:4px 0'>", unsafe_allow_html=True)
                # Examples are plain paths (str); uploads expose ``.name``.
                # splitext removes the extension (str.strip would remove any
                # of the characters '.', 'f', 'i', 't', 's' from both ends).
                source_name = getattr(uploaded_file, "name", uploaded_file)
                fname = os.path.splitext(os.path.basename(source_name))[0]
                shutil.make_archive("predictions", 'zip', "predictions")
                with open('predictions.zip', 'rb') as f:
                    res = f.read()
                st.download_button(label="Download", data=res, key="download",
                                   file_name=f'{fname}_{int(scale*128)}.zip',
                                   mime="application/octet-stream")