File size: 4,428 Bytes
7454012 6414f94 2680aa6 7454012 6414f94 3190548 d01a85f 7454012 248cf5e 31764f2 6414f94 7454012 6414f94 7454012 d26c581 6414f94 ff9f313 c1f2126 6414f94 d01a85f ddce24a c1f2126 d01a85f 2680aa6 c1f2126 3190548 0e52fa4 3190548 d01a85f 3190548 8120054 c1f2126 6414f94 0e52fa4 6414f94 7454012 6414f94 d26c581 7454012 d26c581 7454012 d26c581 44cb863 6414f94 b635f79 6414f94 9c67ee7 0ded826 0e52fa4 9c67ee7 778987b e5f809c 7454012 0e52fa4 0ded826 0a4f4ea 0e52fa4 05f58f4 8e1fb51 0ded826 0e52fa4 d01a85f 0e52fa4 0a4f4ea e5f809c 0a4f4ea d01a85f 0a4f4ea 16cf1a4 0ded826 fb43d84 d01a85f 0ded826 5bbb151 0ded826 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 |
# Standard library
import io
# Basic libraries
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
# Astropy
from astropy.io import fits
from astropy.wcs import WCS
from astropy.nddata import Cutout2D, CCDData
from astropy.convolution import Gaussian2DKernel as Gauss
from astropy.convolution import convolve
# HuggingFace
from huggingface_hub import from_pretrained_keras
# Streamlit
import streamlit as st

# `st.cache_resource` exists from Streamlit 1.18; fall back to its predecessor
# `st.experimental_singleton` on older releases.
_cache_resource = getattr(st, "cache_resource", None) or st.experimental_singleton


@_cache_resource
def _load_model():
    """Load the CADET Keras model from the Hugging Face Hub.

    Streamlit re-executes this entire script on every widget interaction, so
    the loaded model is cached as a shared resource instead of being reloaded
    on each rerun.
    """
    return from_pretrained_keras("Plsek/CADET-v1")


model = _load_model()
# Suppresses the warning for calling st.pyplot() without a figure argument.
# NOTE(review): this option was removed in newer Streamlit releases - confirm
# against the pinned Streamlit version.
st.set_option('deprecation.showPyplotGlobalUse', False)
st.title("Cavity Detection Tool")
st.markdown("Cavity Detection Tool (CADET) is a machine learning pipeline trained to detect X-ray cavities from noisy Chandra images of early-type galaxies. To use this tool: upload your image, select the scale of interest, and make a prediction! If you use the output of this tool in your research please cite [Plšek et al. 2023](https://arxiv.org/abs/2304.05457)")
# Create file uploader widget
uploaded_file = st.file_uploader("Choose a FITS file", type=['fits'])
# Define function to plot the uploaded image
def plot_image(image, scale):
    """Show ``image`` in the left plot column with the analysed region outlined.

    A white, unfilled square with a side of ``scale * 128`` pixels is drawn
    around the image centre, marking the region that ``cut`` extracts and
    rebins to 128x128 pixels.

    Parameters
    ----------
    image : 2D array
        Image to display (drawn with ``origin="lower"``).
    scale : int
        Rebinning factor selected in the UI.

    Notes
    -----
    Renders into the module-level Streamlit column ``colA``, which must exist
    before this function is called.
    """
    size = scale * 128
    # Centre each axis separately so non-square images are outlined correctly;
    # Rectangle takes its corner as (x, y) = (column, row).
    x0 = image.shape[1] // 2 - size / 2
    y0 = image.shape[0] // 2 - size / 2

    fig, ax = plt.subplots(figsize=(4, 4))
    ax.imshow(image, origin="lower")
    ax.add_patch(Rectangle((x0, y0), size, size, linewidth=1,
                           edgecolor='w', facecolor='none'))
    ax.axis('off')
    with colA:
        st.pyplot(fig)
    # pyplot keeps every figure alive until it is closed, and this function
    # runs on every Streamlit rerun - close it to avoid leaking figures.
    plt.close(fig)
# Define function to smooth image
def smooth_image(image, scale):
    """Return ``image`` convolved with a circular Gaussian kernel (sigma = 2 px).

    Edges are handled with ``boundary="wrap"`` and NaN pixels with
    ``nan_treatment="interpolate"``. ``scale`` is currently unused.
    """
    gaussian = Gauss(x_stddev=2, y_stddev=2)
    return convolve(
        image,
        kernel=gaussian,
        boundary="wrap",
        nan_treatment="interpolate",
    )
# Define function to plot the prediction
def plot_prediction(pred):
    """Show the prediction map ``pred`` in the right plot column.

    Parameters
    ----------
    pred : 2D array
        Prediction map to display (drawn with ``origin="lower"``).

    Notes
    -----
    Renders into the module-level Streamlit column ``colB``, which must exist
    before this function is called.
    """
    fig, ax = plt.subplots(figsize=(4, 4))
    ax.imshow(pred, origin="lower")
    ax.axis('off')
    with colB:
        st.pyplot(fig)
    # Close explicitly so repeated reruns do not accumulate open figures.
    plt.close(fig)
# Cut input image and rebin it to 128x128 pixels
def cut(data0, wcs0, scale=1):
    """Cut a central ``128*scale`` square from an image and rebin it to 128x128.

    Parameters
    ----------
    data0 : 2D array
        Input image.
    wcs0 : astropy.wcs.WCS
        WCS of ``data0``.
    scale : int, optional
        Integer rebinning factor. The cutout is ``128 * scale`` pixels wide and
        each output pixel is the mean of a ``scale x scale`` block.

    Returns
    -------
    data : ndarray of shape (128, 128)
        Rebinned cutout.
    wcs : astropy.wcs.WCS
        WCS of the rebinned cutout.

    Raises
    ------
    ValueError
        If ``data0`` is too small to contain the requested cutout.
    """
    size = 128 * scale
    # Cutout2D expects the position as (x, y) = (column, row).
    center = (data0.shape[1] / 2, data0.shape[0] / 2)
    cutout = Cutout2D(data0, center, (size, size), wcs=wcs0)
    data, wcs = cutout.data, cutout.wcs
    # Cutout2D trims at the image edge by default; fail clearly instead of
    # with an opaque reshape error below.
    if data.shape != (size, size):
        raise ValueError(
            f"Image of shape {data0.shape} is too small for a "
            f"{size}x{size} pixel cutout (scale={scale})."
        )

    # Regrid data: average each factor x factor block into one output pixel.
    factor = size // 128
    data = data.reshape(128, factor, 128, factor).mean(axis=(1, 3))

    # Regrid WCS. The reference point is the cutout centre, which lies at
    # (size - 1) / 2 in 0-based input pixels and at 64.5 in 1-based output
    # pixels for every rebinning factor, so the two frames stay aligned.
    c = (size - 1) / 2
    ra, dec = wcs.wcs_pix2world(np.array([[c, c]]), 0)[0]
    if wcs.wcs.has_cd():
        # When a CD matrix is present, CDELT is ignored, so scale CD instead.
        wcs.wcs.cd = wcs.wcs.cd * factor
    else:
        wcs.wcs.cdelt[0] = wcs.wcs.cdelt[0] * factor
        wcs.wcs.cdelt[1] = wcs.wcs.cdelt[1] * factor
    wcs.wcs.crval[0] = ra
    wcs.wcs.crval[1] = dec
    wcs.wcs.crpix[0] = 64.5
    wcs.wcs.crpix[1] = 64.5
    return data, wcs
# Rotation-averaged model prediction used by the "Detect cavities" button
def _predict_cavities(image, threshold=0.4):
    """Return the thresholded, rotation-averaged CADET prediction for ``image``.

    ``image`` must hold exactly 128*128 values. It is rotated by 0, 90, 180
    and 270 degrees, and the four copies are sent to ``model`` as one batch.
    Each prediction is rotated back, and the four maps are averaged. Values
    not greater than ``threshold`` are set to 0.
    """
    rotated = np.stack([np.rot90(image, k) for k in range(4)])
    preds = model.predict(rotated.reshape(4, 128, 128, 1)).reshape(4, 128, 128)
    # Undo each rotation before averaging so all maps share one orientation.
    y_pred = np.mean([np.rot90(pred, -k) for k, pred in enumerate(preds)], axis=0)
    # Thresholding
    return np.where(y_pred > threshold, y_pred, 0)


# If file is uploaded, read in the data and plot it
if uploaded_file is not None:
    with fits.open(uploaded_file) as hdul:
        data = hdul[0].data
        wcs = WCS(hdul[0].header)

    # Report unusable input instead of crashing: the pipeline needs a 2D
    # image at least 128 px along each axis (otherwise the scale list is empty).
    if data is None or data.ndim != 2:
        st.error("The primary HDU of the uploaded file does not contain a 2D image.")
        st.stop()
    max_scale = int(min(data.shape) // 128)
    if max_scale < 1:
        st.error("The image must be at least 128x128 pixels.")
        st.stop()

    # Make four columns for buttons
    col1, col2, col3, col4 = st.columns(4)
    col1.subheader("Input image")
    col3.subheader("Prediction")

    # Buttons and the scale selector
    with col1:
        smooth = st.button("Smooth")
    with col2:
        st.markdown("""<style>[data-baseweb="select"] {margin-top: 17px;}</style>""", unsafe_allow_html=True)
        scale = int(st.selectbox('Scale:', list(range(1, max_scale + 1)), label_visibility="hidden"))
    with col3:
        detect = st.button('Detect cavities')

    # Make two columns for plots (plot_image / plot_prediction draw into them)
    colA, colB = st.columns(2)

    image = np.log10(data + 1)
    if smooth:
        image = smooth_image(image, scale)
    plot_image(image, scale)

    if detect:
        data, wcs = cut(data, wcs, scale=scale)
        image = np.log10(data + 1)
        y_pred = _predict_cavities(image)
        plot_prediction(y_pred)

        # Serialize in memory rather than to a fixed "predicted.fits" on disk,
        # which concurrent sessions of the app would overwrite.
        ccd = CCDData(y_pred, unit="adu", wcs=wcs)
        buffer = io.BytesIO()
        ccd.to_hdu().writeto(buffer)
        res = buffer.getvalue()

    with col4:
        pass
        # TODO: `res` only exists in a run where "Detect cavities" was clicked;
        # persist it (e.g. in st.session_state) before enabling the download.
        # st.markdown("""<style>[data-baseweb="select"] {margin-top: 32px;}</style>""", unsafe_allow_html=True)
        # # download = st.button('Download')
        # download = st.download_button(label="Download", data=res, file_name="predicted.fits", mime="application/octet-stream")