Update app.py
Browse files
app.py
CHANGED
|
@@ -11,31 +11,39 @@ st.set_option('deprecation.showPyplotGlobalUse', False)
|
|
| 11 |
|
| 12 |
st.title("Cavity Detection Tool")
|
| 13 |
|
|
|
|
|
|
|
|
|
|
| 14 |
model = load_model("CADET.hdf5")
|
| 15 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
# Define function to plot the uploaded image
|
| 17 |
def plot_image(image_array, scale):
|
| 18 |
-
# st.set_plot_config(plt, figsize=(4, 4))
|
| 19 |
plt.figure(figsize=(4, 4))
|
| 20 |
-
|
| 21 |
x0 = image_array.shape[0] // 2 - scale * 128 / 2
|
| 22 |
plt.imshow(image_array, origin="lower")
|
| 23 |
plt.gca().add_patch(Rectangle((x0, x0), scale*128, scale*128, linewidth=1, edgecolor='w', facecolor='none'))
|
|
|
|
| 24 |
plt.axis('off')
|
| 25 |
-
|
|
|
|
| 26 |
|
| 27 |
# Define function to plot the prediction
|
| 28 |
def plot_prediction(image_array, pred):
|
| 29 |
-
|
| 30 |
-
plt.
|
| 31 |
-
plt.
|
| 32 |
-
plt.
|
| 33 |
-
plt.axis('off')
|
| 34 |
|
| 35 |
-
plt.subplot(1, 2, 2)
|
| 36 |
plt.imshow(pred, origin="lower")
|
| 37 |
plt.axis('off')
|
| 38 |
-
|
| 39 |
|
| 40 |
def cut(data0, wcs0, scale=1):
|
| 41 |
shape = data0.shape[0]
|
|
@@ -59,8 +67,6 @@ def cut(data0, wcs0, scale=1):
|
|
| 59 |
|
| 60 |
return data, wcs
|
| 61 |
|
| 62 |
-
# Create file uploader widget
|
| 63 |
-
uploaded_file = st.file_uploader("Choose a FITS file", type=['fits'])
|
| 64 |
|
| 65 |
# If file is uploaded, read in the data and plot it
|
| 66 |
if uploaded_file is not None:
|
|
|
|
| 11 |
|
| 12 |
st.title("Cavity Detection Tool")
|
| 13 |
|
| 14 |
+
st.text("Cavity Detection Tool (CADET) is a machine learning pipeline trained to detect X-ray cavities from noisy Chandra images of early-type galaxies.\
|
| 15 |
+
To use this tool upload your image, select the scale of interest and make a prediction! If you use output of this tool in your research please cite [Plšek et al. 2023](https://arxiv.org/abs/2304.05457)")
|
| 16 |
+
|
| 17 |
model = load_model("CADET.hdf5")
|
| 18 |
|
| 19 |
+
# Create file uploader widget
|
| 20 |
+
uploaded_file = st.file_uploader("Choose a FITS file", type=['fits'])
|
| 21 |
+
|
| 22 |
+
col1, col2 = st.columns(2)
|
| 23 |
+
|
| 24 |
# Define function to plot the uploaded image
|
| 25 |
def plot_image(image_array, scale):
|
|
|
|
| 26 |
plt.figure(figsize=(4, 4))
|
| 27 |
+
|
| 28 |
x0 = image_array.shape[0] // 2 - scale * 128 / 2
|
| 29 |
plt.imshow(image_array, origin="lower")
|
| 30 |
plt.gca().add_patch(Rectangle((x0, x0), scale*128, scale*128, linewidth=1, edgecolor='w', facecolor='none'))
|
| 31 |
+
|
| 32 |
plt.axis('off')
|
| 33 |
+
fig.set_size_inches((4,4))
|
| 34 |
+
st.pyplot()
|
| 35 |
|
| 36 |
# Define function to plot the prediction
|
| 37 |
def plot_prediction(image_array, pred):
|
| 38 |
+
plt.figure(figsize=(4, 4))
|
| 39 |
+
# plt.subplot(1, 2, 1)
|
| 40 |
+
# plt.imshow(image_array, origin="lower")
|
| 41 |
+
# plt.axis('off')
|
|
|
|
| 42 |
|
| 43 |
+
# plt.subplot(1, 2, 2)
|
| 44 |
plt.imshow(pred, origin="lower")
|
| 45 |
plt.axis('off')
|
| 46 |
+
col2.pyplot()
|
| 47 |
|
| 48 |
def cut(data0, wcs0, scale=1):
|
| 49 |
shape = data0.shape[0]
|
|
|
|
| 67 |
|
| 68 |
return data, wcs
|
| 69 |
|
|
|
|
|
|
|
| 70 |
|
| 71 |
# If file is uploaded, read in the data and plot it
|
| 72 |
if uploaded_file is not None:
|