Plsek commited on
Commit
c1f2126
·
1 Parent(s): 415f0ce

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -12
app.py CHANGED
@@ -11,31 +11,39 @@ st.set_option('deprecation.showPyplotGlobalUse', False)
11
 
12
  st.title("Cavity Detection Tool")
13
 
 
 
 
14
  model = load_model("CADET.hdf5")
15
 
 
 
 
 
 
16
  # Define function to plot the uploaded image
17
  def plot_image(image_array, scale):
18
- # st.set_plot_config(plt, figsize=(4, 4))
19
  plt.figure(figsize=(4, 4))
20
- # plt.subplot(1, 2, 1)
21
  x0 = image_array.shape[0] // 2 - scale * 128 / 2
22
  plt.imshow(image_array, origin="lower")
23
  plt.gca().add_patch(Rectangle((x0, x0), scale*128, scale*128, linewidth=1, edgecolor='w', facecolor='none'))
 
24
  plt.axis('off')
25
- st.pyplot(width=200)
 
26
 
27
  # Define function to plot the prediction
28
  def plot_prediction(image_array, pred):
29
- # st.set_plot_config(plt, figsize=(8, 4))
30
- plt.figure(figsize=(8, 4))
31
- plt.subplot(1, 2, 1)
32
- plt.imshow(image_array, origin="lower")
33
- plt.axis('off')
34
 
35
- plt.subplot(1, 2, 2)
36
  plt.imshow(pred, origin="lower")
37
  plt.axis('off')
38
- st.pyplot(width=400)
39
 
40
  def cut(data0, wcs0, scale=1):
41
  shape = data0.shape[0]
@@ -59,8 +67,6 @@ def cut(data0, wcs0, scale=1):
59
 
60
  return data, wcs
61
 
62
- # Create file uploader widget
63
- uploaded_file = st.file_uploader("Choose a FITS file", type=['fits'])
64
 
65
  # If file is uploaded, read in the data and plot it
66
  if uploaded_file is not None:
 
11
 
12
  st.title("Cavity Detection Tool")
13
 
14
+ st.text("Cavity Detection Tool (CADET) is a machine learning pipeline trained to detect X-ray cavities from noisy Chandra images of early-type galaxies.\
15
+ To use this tool upload your image, select the scale of interest and make a prediction! If you use output of this tool in your research please cite [Plšek et al. 2023](https://arxiv.org/abs/2304.05457)")
16
+
17
  model = load_model("CADET.hdf5")
18
 
19
+ # Create file uploader widget
20
+ uploaded_file = st.file_uploader("Choose a FITS file", type=['fits'])
21
+
22
+ col1, col2 = st.columns(2)
23
+
24
  # Define function to plot the uploaded image
25
  def plot_image(image_array, scale):
 
26
  plt.figure(figsize=(4, 4))
27
+
28
  x0 = image_array.shape[0] // 2 - scale * 128 / 2
29
  plt.imshow(image_array, origin="lower")
30
  plt.gca().add_patch(Rectangle((x0, x0), scale*128, scale*128, linewidth=1, edgecolor='w', facecolor='none'))
31
+
32
  plt.axis('off')
33
+ fig.set_size_inches((4,4))
34
+ st.pyplot()
35
 
36
  # Define function to plot the prediction
37
  def plot_prediction(image_array, pred):
38
+ plt.figure(figsize=(4, 4))
39
+ # plt.subplot(1, 2, 1)
40
+ # plt.imshow(image_array, origin="lower")
41
+ # plt.axis('off')
 
42
 
43
+ # plt.subplot(1, 2, 2)
44
  plt.imshow(pred, origin="lower")
45
  plt.axis('off')
46
+ col2.pyplot()
47
 
48
  def cut(data0, wcs0, scale=1):
49
  shape = data0.shape[0]
 
67
 
68
  return data, wcs
69
 
 
 
70
 
71
  # If file is uploaded, read in the data and plot it
72
  if uploaded_file is not None: