# CADET cavity-detection app: library imports, model loading and Streamlit
# options. The helper functions further below are currently commented out.

# HuggingFace Hub
from huggingface_hub import from_pretrained_keras

# Basic libraries
import os
import shutil
import numpy as np
from scipy.ndimage import center_of_mass
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
from matplotlib.patches import Rectangle

# Astropy
from astropy.io import fits
from astropy.wcs import WCS
from astropy.nddata import Cutout2D, CCDData
from astropy.convolution import Gaussian2DKernel as Gauss
from astropy.convolution import convolve

# Scikit-learn
from sklearn.cluster import DBSCAN

# Streamlit
import streamlit as st
from streamlit.errors import StreamlitAPIException

# Pretrained CADET network fetched from the HuggingFace Hub (loaded once all
# imports are done so import errors surface before the download starts).
model = from_pretrained_keras("Plsek/CADET-v1")

# Presumably disables the warning Streamlit emits when st.pyplot() is called
# without an explicit figure, as the (disabled) plotting helpers below do.
# NOTE(review): newer Streamlit releases removed this option and raise
# StreamlitAPIException when it is set; ignore that so the app still starts.
try:
    st.set_option('deprecation.showPyplotGlobalUse', False)
except StreamlitAPIException:
    pass

# ---------------------------------------------------------------------------
# Disabled helper functions (every line is commented out). They are kept
# indented so they can be re-enabled by removing the leading "# ".
# NOTE(review) before re-enabling:
#   - plot_image / plot_prediction / plot_decomposed write into colA / colB /
#     colC, which are not defined here; presumably layout columns created by
#     the main app section — TODO confirm.
#   - decompose_cavity reads `wcs`, which is neither a parameter nor a local,
#     so it depends on a global of that name existing at call time.
#   - decompose_cavity wraps DBSCAN in a bare `except:`, which also hides
#     unrelated errors; its `th2` parameter is unused while the
#     "Thresholding #2" check stays commented out.
# ---------------------------------------------------------------------------

# # Define function to plot the uploaded image
# def plot_image(image, scale):
#     plt.figure(figsize=(4, 4))
#     x0 = image.shape[0] // 2 - scale * 128 / 2
#     plt.imshow(image, origin="lower")
#     plt.gca().add_patch(Rectangle((x0-0.5, x0-0.5), scale*128, scale*128, linewidth=1, edgecolor='w', facecolor='none'))
#     plt.axis('off')
#     plt.tight_layout()
#     with colA: st.pyplot()

# # Define function to plot the prediction
# def plot_prediction(pred):
#     plt.figure(figsize=(4, 4))
#     plt.imshow(pred, origin="lower")
#     plt.axis('off')
#     with colB: st.pyplot()

# # Define function to plot the decomposed prediction
# def plot_decomposed(decomposed):
#     plt.figure(figsize=(4, 4))
#     plt.imshow(decomposed, origin="lower") #, norm=LogNorm())
#     N = int(np.max(decomposed))
#     for i in range(N):
#         new = np.where(decomposed == i+1, 1, 0)
#         x0, y0 = center_of_mass(new)
#         color = "white" if i < N//2 else "black"
#         plt.text(y0, x0, f"{i+1}", ha="center", va="center", fontsize=15, color=color)
#     plt.axis('off')
#     with colC: st.pyplot()

# # Define function to cut input image and rebin it to 128x128 pixels
# def cut(data0, wcs0, scale=1):
#     shape = data0.shape[0]
#     x0 = shape / 2
#     size = 128 * scale
#     cutout = Cutout2D(data0, (x0, x0), (size, size), wcs=wcs0)
#     data, wcs = cutout.data, cutout.wcs
#     # Regrid data
#     factor = size // 128
#     data = data.reshape(128, factor, 128, factor).mean(-1).mean(1)
#     # Regrid wcs
#     ra, dec = wcs.wcs_pix2world(np.array([[63, 63]]),0)[0]
#     wcs.wcs.cdelt[0] = wcs.wcs.cdelt[0] * factor
#     wcs.wcs.cdelt[1] = wcs.wcs.cdelt[1] * factor
#     wcs.wcs.crval[0] = ra
#     wcs.wcs.crval[1] = dec
#     wcs.wcs.crpix[0] = 64 / factor
#     wcs.wcs.crpix[1] = 64 / factor
#     return data, wcs

# # Define function to apply cutting and produce a prediction
# # @st.cache
# def cut_n_predict(data, wcs, scale):
#     data, wcs = cut(data, wcs, scale=scale)
#     image = np.log10(data+1)
#     y_pred = 0
#     for j in [0,1,2,3]:
#         rotated = np.rot90(image, j)
#         pred = model.predict(rotated.reshape(1, 128, 128, 1)).reshape(128 ,128)
#         pred = np.rot90(pred, -j)
#         y_pred += pred / 4
#     return y_pred, wcs

# # Define function to decompose prediction into individual cavities
# # @st.cache
# def decompose_cavity(pred, th2=0.7, amin=6):
#     X, Y = pred.nonzero()
#     data = np.array([X,Y]).reshape(2, -1)
#     # DBSCAN clustering
#     try: clusters = DBSCAN(eps=1.0, min_samples=3).fit(data.T).labels_
#     except: clusters = []
#     N = len(set(clusters))
#     cavities = []
#     for i in range(N):
#         img = np.zeros((128,128))
#         b = clusters == i
#         xi, yi = X[b], Y[b]
#         img[xi, yi] = pred[xi, yi]
#         # # Thresholding #2
#         # if not (img > th2).any(): continue
#         # Minimal area
#         if np.sum(img) <= amin: continue
#         cavities.append(img)
#     # Save raw and decomposed predictions to predictions folder
#     ccd = CCDData(pred, unit="adu", wcs=wcs)
#     ccd.write(f"predictions/predicted.fits", overwrite=True)
#     image_decomposed = np.zeros((128,128))
#     for i, cav in enumerate(cavities):
#         ccd = CCDData(cav, unit="adu", wcs=wcs)
#         ccd.write(f"predictions/predicted_{i+1}.fits", overwrite=True)
#         image_decomposed += (i+1) * np.where(cav > 0, 1, 0)
#     # shutil.make_archive("predictions", 'zip', "predictions")
#     return image_decomposed

# # @st.cache
# def load_file(fname):
#     with fits.open(fname) as hdul:
#         data = hdul[0].data
#         wcs = WCS(hdul[0].header)
#     return data, wcs
# def change_scale(): # del st.session_state["threshold"] # # Use wide layout and create columns # st.set_page_config(page_title="Cavity Detection Tool", layout="wide") # bordersize = 0.45 # _, col, _ = st.columns([bordersize, 3, bordersize]) # os.system("mkdir -p predictions") # with col: # # Create heading and description # st.markdown("