Update app.py
Browse files
app.py
CHANGED
@@ -11,31 +11,39 @@ st.set_option('deprecation.showPyplotGlobalUse', False)
|
|
11 |
|
12 |
st.title("Cavity Detection Tool")
|
13 |
|
|
|
|
|
|
|
14 |
model = load_model("CADET.hdf5")
|
15 |
|
|
|
|
|
|
|
|
|
|
|
16 |
# Define function to plot the uploaded image
|
17 |
def plot_image(image_array, scale):
|
18 |
-
# st.set_plot_config(plt, figsize=(4, 4))
|
19 |
plt.figure(figsize=(4, 4))
|
20 |
-
|
21 |
x0 = image_array.shape[0] // 2 - scale * 128 / 2
|
22 |
plt.imshow(image_array, origin="lower")
|
23 |
plt.gca().add_patch(Rectangle((x0, x0), scale*128, scale*128, linewidth=1, edgecolor='w', facecolor='none'))
|
|
|
24 |
plt.axis('off')
|
25 |
-
|
|
|
26 |
|
27 |
# Define function to plot the prediction
|
28 |
def plot_prediction(image_array, pred):
|
29 |
-
|
30 |
-
plt.
|
31 |
-
plt.
|
32 |
-
plt.
|
33 |
-
plt.axis('off')
|
34 |
|
35 |
-
plt.subplot(1, 2, 2)
|
36 |
plt.imshow(pred, origin="lower")
|
37 |
plt.axis('off')
|
38 |
-
|
39 |
|
40 |
def cut(data0, wcs0, scale=1):
|
41 |
shape = data0.shape[0]
|
@@ -59,8 +67,6 @@ def cut(data0, wcs0, scale=1):
|
|
59 |
|
60 |
return data, wcs
|
61 |
|
62 |
-
# Create file uploader widget
|
63 |
-
uploaded_file = st.file_uploader("Choose a FITS file", type=['fits'])
|
64 |
|
65 |
# If file is uploaded, read in the data and plot it
|
66 |
if uploaded_file is not None:
|
|
|
11 |
|
12 |
st.title("Cavity Detection Tool")
|
13 |
|
14 |
+
st.text("Cavity Detection Tool (CADET) is a machine learning pipeline trained to detect X-ray cavities from noisy Chandra images of early-type galaxies.\
|
15 |
+
To use this tool upload your image, select the scale of interest and make a prediction! If you use output of this tool in your research please cite [Plšek et al. 2023](https://arxiv.org/abs/2304.05457)")
|
16 |
+
|
17 |
model = load_model("CADET.hdf5")
|
18 |
|
19 |
+
# Create file uploader widget
|
20 |
+
uploaded_file = st.file_uploader("Choose a FITS file", type=['fits'])
|
21 |
+
|
22 |
+
col1, col2 = st.columns(2)
|
23 |
+
|
24 |
# Define function to plot the uploaded image
|
25 |
def plot_image(image_array, scale):
|
|
|
26 |
plt.figure(figsize=(4, 4))
|
27 |
+
|
28 |
x0 = image_array.shape[0] // 2 - scale * 128 / 2
|
29 |
plt.imshow(image_array, origin="lower")
|
30 |
plt.gca().add_patch(Rectangle((x0, x0), scale*128, scale*128, linewidth=1, edgecolor='w', facecolor='none'))
|
31 |
+
|
32 |
plt.axis('off')
|
33 |
+
fig.set_size_inches((4,4))
|
34 |
+
st.pyplot()
|
35 |
|
36 |
# Define function to plot the prediction
|
37 |
def plot_prediction(image_array, pred):
|
38 |
+
plt.figure(figsize=(4, 4))
|
39 |
+
# plt.subplot(1, 2, 1)
|
40 |
+
# plt.imshow(image_array, origin="lower")
|
41 |
+
# plt.axis('off')
|
|
|
42 |
|
43 |
+
# plt.subplot(1, 2, 2)
|
44 |
plt.imshow(pred, origin="lower")
|
45 |
plt.axis('off')
|
46 |
+
col2.pyplot()
|
47 |
|
48 |
def cut(data0, wcs0, scale=1):
|
49 |
shape = data0.shape[0]
|
|
|
67 |
|
68 |
return data, wcs
|
69 |
|
|
|
|
|
70 |
|
71 |
# If file is uploaded, read in the data and plot it
|
72 |
if uploaded_file is not None:
|