Update app.py
Browse files
app.py
CHANGED
@@ -1,27 +1,33 @@
|
|
1 |
-
|
2 |
import numpy as np
|
3 |
import matplotlib.pyplot as plt
|
4 |
from matplotlib.patches import Rectangle
|
|
|
|
|
5 |
from astropy.io import fits
|
6 |
from astropy.wcs import WCS
|
7 |
from astropy.nddata import Cutout2D, CCDData
|
|
|
|
|
8 |
from tensorflow.keras.models import load_model
|
|
|
9 |
|
|
|
|
|
10 |
st.set_option('deprecation.showPyplotGlobalUse', False)
|
11 |
|
|
|
12 |
st.title("Cavity Detection Tool")
|
13 |
|
14 |
st.markdown("Cavity Detection Tool (CADET) is a machine learning pipeline trained to detect X-ray cavities from noisy Chandra images of early-type galaxies.To use this tool upload your image, select the scale of interest and make a prediction! If you use output of this tool in your research please cite [Plšek et al. 2023](https://arxiv.org/abs/2304.05457)")
|
15 |
|
16 |
-
model = load_model("CADET.hdf5")
|
17 |
-
|
18 |
# Create file uploader widget
|
19 |
uploaded_file = st.file_uploader("Choose a FITS file", type=['fits'])
|
20 |
|
|
|
21 |
col1, col2 = st.columns(2)
|
22 |
-
|
23 |
col1.subheader("Input image")
|
24 |
-
|
25 |
|
26 |
# Define function to plot the uploaded image
|
27 |
def plot_image(image_array, scale):
|
@@ -32,7 +38,6 @@ def plot_image(image_array, scale):
|
|
32 |
plt.gca().add_patch(Rectangle((x0, x0), scale*128, scale*128, linewidth=1, edgecolor='w', facecolor='none'))
|
33 |
|
34 |
plt.axis('off')
|
35 |
-
# with col1:
|
36 |
st.pyplot()
|
37 |
|
38 |
# Define function to plot the prediction
|
@@ -43,6 +48,7 @@ def plot_prediction(pred):
|
|
43 |
# with col2:
|
44 |
st.pyplot()
|
45 |
|
|
|
46 |
def cut(data0, wcs0, scale=1):
|
47 |
shape = data0.shape[0]
|
48 |
x0 = shape / 2
|
@@ -50,11 +56,11 @@ def cut(data0, wcs0, scale=1):
|
|
50 |
cutout = Cutout2D(data0, (x0, x0), (size, size), wcs=wcs0)
|
51 |
data, wcs = cutout.data, cutout.wcs
|
52 |
|
53 |
-
#
|
54 |
factor = size // 128
|
55 |
data = data.reshape(128, factor, 128, factor).mean(-1).mean(1)
|
56 |
|
57 |
-
#
|
58 |
ra, dec = wcs.wcs_pix2world(np.array([[63, 63]]),0)[0]
|
59 |
wcs.wcs.cdelt[0] = wcs.wcs.cdelt[0] * factor
|
60 |
wcs.wcs.cdelt[1] = wcs.wcs.cdelt[1] * factor
|
@@ -71,33 +77,32 @@ if uploaded_file is not None:
|
|
71 |
with fits.open(uploaded_file) as hdul:
|
72 |
data = hdul[0].data
|
73 |
wcs = WCS(hdul[0].header)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
74 |
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
|
|
|
|
79 |
|
80 |
-
|
81 |
|
82 |
-
|
83 |
-
|
84 |
-
data, wcs = cut(data, wcs, scale=scale)
|
85 |
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
rotated = np.rot90(image_data, j)
|
91 |
-
pred = model.predict(rotated.reshape(1, 128, 128, 1)).reshape(128 ,128)
|
92 |
-
pred = np.rot90(pred, -j)
|
93 |
-
y_pred += pred / 4
|
94 |
-
|
95 |
-
# ccd = CCDData(pred, unit="adu", wcs=wcs)
|
96 |
-
# ccd.write(f"predicted.fits", overwrite=True)
|
97 |
-
|
98 |
-
plot_prediction(y_pred)
|
99 |
-
|
100 |
-
# if st.button('Download FITS File'):
|
101 |
-
# with open('predicted.fits', 'rb') as f:
|
102 |
-
# data = f.read()
|
103 |
-
# st.download_button(label="Download", data=data, file_name="predicted.fits", mime="application/octet-stream")
|
|
|
1 |
+
# Basic libraries
|
2 |
import numpy as np
|
3 |
import matplotlib.pyplot as plt
|
4 |
from matplotlib.patches import Rectangle
|
5 |
+
|
6 |
+
# Astropy
|
7 |
from astropy.io import fits
|
8 |
from astropy.wcs import WCS
|
9 |
from astropy.nddata import Cutout2D, CCDData
|
10 |
+
|
11 |
+
# Tensorflow
|
12 |
from tensorflow.keras.models import load_model
|
13 |
+
model = load_model("CADET.hdf5")
|
14 |
|
15 |
+
# Streamlit
|
16 |
+
import streamlit as st
|
17 |
st.set_option('deprecation.showPyplotGlobalUse', False)
|
18 |
|
19 |
+
|
20 |
st.title("Cavity Detection Tool")
|
21 |
|
22 |
st.markdown("Cavity Detection Tool (CADET) is a machine learning pipeline trained to detect X-ray cavities from noisy Chandra images of early-type galaxies.To use this tool upload your image, select the scale of interest and make a prediction! If you use output of this tool in your research please cite [Plšek et al. 2023](https://arxiv.org/abs/2304.05457)")
|
23 |
|
|
|
|
|
24 |
# Create file uploader widget
|
25 |
uploaded_file = st.file_uploader("Choose a FITS file", type=['fits'])
|
26 |
|
27 |
+
# Make two columns
|
28 |
col1, col2 = st.columns(2)
|
|
|
29 |
col1.subheader("Input image")
|
30 |
+
col2.subheader("CADET prediction")
|
31 |
|
32 |
# Define function to plot the uploaded image
|
33 |
def plot_image(image_array, scale):
|
|
|
38 |
plt.gca().add_patch(Rectangle((x0, x0), scale*128, scale*128, linewidth=1, edgecolor='w', facecolor='none'))
|
39 |
|
40 |
plt.axis('off')
|
|
|
41 |
st.pyplot()
|
42 |
|
43 |
# Define function to plot the prediction
|
|
|
48 |
# with col2:
|
49 |
st.pyplot()
|
50 |
|
51 |
+
# Cut input image and rebin it to 128x128 pixels
|
52 |
def cut(data0, wcs0, scale=1):
|
53 |
shape = data0.shape[0]
|
54 |
x0 = shape / 2
|
|
|
56 |
cutout = Cutout2D(data0, (x0, x0), (size, size), wcs=wcs0)
|
57 |
data, wcs = cutout.data, cutout.wcs
|
58 |
|
59 |
+
# Regrid data
|
60 |
factor = size // 128
|
61 |
data = data.reshape(128, factor, 128, factor).mean(-1).mean(1)
|
62 |
|
63 |
+
# Regrid wcs
|
64 |
ra, dec = wcs.wcs_pix2world(np.array([[63, 63]]),0)[0]
|
65 |
wcs.wcs.cdelt[0] = wcs.wcs.cdelt[0] * factor
|
66 |
wcs.wcs.cdelt[1] = wcs.wcs.cdelt[1] * factor
|
|
|
77 |
with fits.open(uploaded_file) as hdul:
|
78 |
data = hdul[0].data
|
79 |
wcs = WCS(hdul[0].header)
|
80 |
+
image = np.log10(data+1)
|
81 |
+
|
82 |
+
# Add a slider to change the scale
|
83 |
+
with col1:
|
84 |
+
max_scale = int(data.shape[0] // 128)
|
85 |
+
scale = st.slider("Scale", 1, max_scale, 1, 1)
|
86 |
+
|
87 |
+
plot_image(image, scale)
|
88 |
+
|
89 |
+
with col2:
|
90 |
+
st.button('Detect cavities')
|
91 |
+
data, wcs = cut(data, wcs, scale=scale)
|
92 |
|
93 |
+
y_pred = 0
|
94 |
+
for j in [0,1,2,3]:
|
95 |
+
rotated = np.rot90(image, j)
|
96 |
+
pred = model.predict(rotated.reshape(1, 128, 128, 1)).reshape(128 ,128)
|
97 |
+
pred = np.rot90(pred, -j)
|
98 |
+
y_pred += pred / 4
|
99 |
|
100 |
+
plot_prediction(y_pred)
|
101 |
|
102 |
+
# ccd = CCDData(pred, unit="adu", wcs=wcs)
|
103 |
+
# ccd.write(f"predicted.fits", overwrite=True)
|
|
|
104 |
|
105 |
+
# if st.button('Download FITS File'):
|
106 |
+
# with open('predicted.fits', 'rb') as f:
|
107 |
+
# data = f.read()
|
108 |
+
# st.download_button(label="Download", data=data, file_name="predicted.fits", mime="application/octet-stream")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|