File size: 9,041 Bytes
0012f0e 57f9dda 0ea9e86 7454012 e52c641 2bb6d79 6414f94 b4f54b4 6414f94 8111624 2680aa6 7454012 6414f94 3190548 d01a85f 7454012 1197a1f 7454012 6414f94 08bba4c 8111624 08bba4c 29cd1e9 08bba4c d26c581 08bba4c a8b8828 08bba4c 6b286d7 08bba4c e9ef2e7 08bba4c 00e1b14 08bba4c d083cdb 00e1b14 d083cdb 08bba4c d083cdb 08bba4c cb61555 08bba4c 992ff70 08bba4c 992ff70 08bba4c 992ff70 08bba4c 3ba01a0 08bba4c b046ded 08bba4c d0ce4b5 08bba4c d083cdb 08bba4c 60663d5 08bba4c 0e52fa4 08bba4c cd05002 08bba4c cd05002 08bba4c 60663d5 08bba4c 60663d5 08bba4c 60663d5 08bba4c d083cdb 08bba4c cd05002 08bba4c d083cdb 08bba4c d083cdb 08bba4c d083cdb 08bba4c d083cdb 08bba4c 58946f4 08bba4c 2af3442 08bba4c d083cdb 08bba4c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 |
# Basic libraries
import os
import shutil
import numpy as np
from scipy.ndimage import center_of_mass
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
from matplotlib.patches import Rectangle
# Astropy
from astropy.io import fits
from astropy.wcs import WCS
from astropy.nddata import Cutout2D, CCDData
from astropy.convolution import Gaussian2DKernel as Gauss
from astropy.convolution import convolve
# Scikit-learn
from sklearn.cluster import DBSCAN
# Streamlit
import streamlit as st
from streamlit.errors import StreamlitAPIException
# HuggingFace Hub
from huggingface_hub import from_pretrained_keras

# Pre-trained CADET network, fetched from the HuggingFace Hub ("Plsek/CADET-v1").
# This runs once at import time and needs network access (or a warm HF cache).
# It is placed after all imports so import errors surface before the slow download.
model = from_pretrained_keras("Plsek/CADET-v1")

# Suppress Streamlit's warning about calling st.pyplot() without an explicit figure.
# NOTE(review): this config key exists only in older Streamlit releases; newer
# releases reject unknown keys with StreamlitAPIException. The call is purely
# cosmetic, so a missing option must not crash the app at startup.
try:
    st.set_option('deprecation.showPyplotGlobalUse', False)
except StreamlitAPIException:
    pass
# # Define function to plot the uploaded image
# def plot_image(image, scale):
# plt.figure(figsize=(4, 4))
# x0 = image.shape[0] // 2 - scale * 128 / 2
# plt.imshow(image, origin="lower")
# plt.gca().add_patch(Rectangle((x0-0.5, x0-0.5), scale*128, scale*128, linewidth=1, edgecolor='w', facecolor='none'))
# plt.axis('off')
# plt.tight_layout()
# with colA: st.pyplot()
# # Define function to plot the prediction
# def plot_prediction(pred):
# plt.figure(figsize=(4, 4))
# plt.imshow(pred, origin="lower")
# plt.axis('off')
# with colB: st.pyplot()
# # Define function to plot the decomposed prediction
# def plot_decomposed(decomposed):
# plt.figure(figsize=(4, 4))
# plt.imshow(decomposed, origin="lower") #, norm=LogNorm())
# N = int(np.max(decomposed))
# for i in range(N):
# new = np.where(decomposed == i+1, 1, 0)
# x0, y0 = center_of_mass(new)
# color = "white" if i < N//2 else "black"
# plt.text(y0, x0, f"{i+1}", ha="center", va="center", fontsize=15, color=color)
# plt.axis('off')
# with colC: st.pyplot()
# # Define function to cut input image and rebin it to 128x128 pixels
# def cut(data0, wcs0, scale=1):
# shape = data0.shape[0]
# x0 = shape / 2
# size = 128 * scale
# cutout = Cutout2D(data0, (x0, x0), (size, size), wcs=wcs0)
# data, wcs = cutout.data, cutout.wcs
# # Regrid data
# factor = size // 128
# data = data.reshape(128, factor, 128, factor).mean(-1).mean(1)
# # Regrid wcs
# ra, dec = wcs.wcs_pix2world(np.array([[63, 63]]),0)[0]
# wcs.wcs.cdelt[0] = wcs.wcs.cdelt[0] * factor
# wcs.wcs.cdelt[1] = wcs.wcs.cdelt[1] * factor
# wcs.wcs.crval[0] = ra
# wcs.wcs.crval[1] = dec
# wcs.wcs.crpix[0] = 64 / factor
# wcs.wcs.crpix[1] = 64 / factor
# return data, wcs
# # Define function to apply cutting and produce a prediction
# # @st.cache
# def cut_n_predict(data, wcs, scale):
# data, wcs = cut(data, wcs, scale=scale)
# image = np.log10(data+1)
# y_pred = 0
# for j in [0,1,2,3]:
# rotated = np.rot90(image, j)
# pred = model.predict(rotated.reshape(1, 128, 128, 1)).reshape(128 ,128)
# pred = np.rot90(pred, -j)
# y_pred += pred / 4
# return y_pred, wcs
# # Define function to decompose prediction into individual cavities
# # @st.cache
# def decompose_cavity(pred, th2=0.7, amin=6):
# X, Y = pred.nonzero()
# data = np.array([X,Y]).reshape(2, -1)
# # DBSCAN clustering
# try: clusters = DBSCAN(eps=1.0, min_samples=3).fit(data.T).labels_
# except: clusters = []
# N = len(set(clusters))
# cavities = []
# for i in range(N):
# img = np.zeros((128,128))
# b = clusters == i
# xi, yi = X[b], Y[b]
# img[xi, yi] = pred[xi, yi]
# # # Thresholding #2
# # if not (img > th2).any(): continue
# # Minimal area
# if np.sum(img) <= amin: continue
# cavities.append(img)
# # Save raw and decomposed predictions to predictions folder
# ccd = CCDData(pred, unit="adu", wcs=wcs)
# ccd.write(f"predictions/predicted.fits", overwrite=True)
# image_decomposed = np.zeros((128,128))
# for i, cav in enumerate(cavities):
# ccd = CCDData(cav, unit="adu", wcs=wcs)
# ccd.write(f"predictions/predicted_{i+1}.fits", overwrite=True)
# image_decomposed += (i+1) * np.where(cav > 0, 1, 0)
# # shutil.make_archive("predictions", 'zip', "predictions")
# return image_decomposed
# # @st.cache
# def load_file(fname):
# with fits.open(fname) as hdul:
# data = hdul[0].data
# wcs = WCS(hdul[0].header)
# return data, wcs
# def change_scale():
# del st.session_state["threshold"]
# # Use wide layout and create columns
# st.set_page_config(page_title="Cavity Detection Tool", layout="wide")
# bordersize = 0.45
# _, col, _ = st.columns([bordersize, 3, bordersize])
# os.system("mkdir -p predictions")
# with col:
# # Create heading and description
# st.markdown("<h1 align='center'>Cavity Detection Tool</h1>", unsafe_allow_html=True)
# st.markdown("Cavity Detection Tool (CADET) is a machine learning pipeline trained to detect X-ray cavities from noisy Chandra images of early-type galaxies.")
# st.markdown("To use this tool: upload your image, select the scale of interest, make a prediction, and decompose it into individual cavities!")
# st.markdown("Input images should be in units of counts, centred at the galaxy center, and point sources should be filled with surrounding background ([dmfilth](https://cxc.cfa.harvard.edu/ciao/ahelp/dmfilth.html)).")
# st.markdown("If you use this tool for your research, please cite [Plšek et al. 2023](https://arxiv.org/abs/2304.05457)")
# # _, col_1, col_2, col_3, _ = st.columns([bordersize, 2.0, 0.5, 0.5, bordersize])
# # with col:
# uploaded_file = st.file_uploader("Choose a FITS file", type=['fits'])
# # with col_2:
# # st.markdown("### Examples")
# # NGC4649 = st.button("NGC4649")
# # with col_3:
# # st.markdown("""<style>[data-baseweb="select"] {margin-top: 26px;}</style>""", unsafe_allow_html=True)
# # NGC5813 = st.button("NGC5813")
# # if NGC4649:
# # uploaded_file = "NGC4649_example.fits"
# # elif NGC5813:
# # uploaded_file = "NGC5813_example.fits"
# # # If file is uploaded, read in the data and plot it
# # if uploaded_file is not None:
# # data, wcs = load_file(uploaded_file)
# # # if "data" not in locals():
# # # data = np.zeros((128,128))
# # # Make six columns for buttons
# # _, col1, col2, col3, col4, col5, col6, _ = st.columns([bordersize,0.5,0.5,0.5,0.5,0.5,0.5,bordersize])
# # col1.subheader("Input image")
# # col3.subheader("Prediction")
# # col5.subheader("Decomposed")
# # col6.subheader("")
# # with col1:
# # st.markdown("""<style>[data-baseweb="select"] {margin-top: -46px;}</style>""", unsafe_allow_html=True)
# # max_scale = int(data.shape[0] // 128)
# # scale = st.selectbox('Scale:',[f"{(i+1)*128}x{(i+1)*128}" for i in range(max_scale)], label_visibility="hidden", on_change=change_scale)
# # scale = int(scale.split("x")[0]) // 128
# # # Detect button
# # with col3: detect = st.button('Detect', key="detect")
# # # Threshold slider
# # with col4:
# # st.markdown("")
# # # st.markdown("""<style>[data-baseweb="select"] {margin-top: -36px;}</style>""", unsafe_allow_html=True)
# # threshold = st.slider("Threshold", 0.0, 1.0, 0.0, 0.05, key="threshold") #, label_visibility="hidden")
# # # Decompose button
# # with col5: decompose = st.button('Decompose', key="decompose")
# # # Make two columns for plots
# # _, colA, colB, colC, _ = st.columns([bordersize,1,1,1,bordersize])
# # image = np.log10(data+1)
# # plot_image(image, scale)
# # if detect or threshold:
# # # if st.session_state.get("detect", True):
# # y_pred, wcs = cut_n_predict(data, wcs, scale)
# # y_pred_th = np.where(y_pred > threshold, y_pred, 0)
# # plot_prediction(y_pred_th)
# # if decompose or st.session_state.get("download", False):
# # image_decomposed = decompose_cavity(y_pred_th)
# # plot_decomposed(image_decomposed)
# # with col6:
# # st.markdown("<br style='margin:4px 0'>", unsafe_allow_html=True)
# # # st.markdown("""<style>[data-baseweb="select"] {margin-top: 16px;}</style>""", unsafe_allow_html=True)
# # fname = uploaded_file.name.strip(".fits")
# # # if st.session_state.get("download", False):
# # shutil.make_archive("predictions", 'zip', "predictions")
# # with open('predictions.zip', 'rb') as f:
# # res = f.read()
# # download = st.download_button(label="Download", data=res, key="download",
# # file_name=f'{fname}_{int(scale*128)}.zip',
# # # disabled=st.session_state.get("disabled", True),
# # mime="application/octet-stream") |