Plsek commited on
Commit
7454012
·
1 Parent(s): 588ff75

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +39 -34
app.py CHANGED
@@ -1,27 +1,33 @@
1
- import streamlit as st
2
  import numpy as np
3
  import matplotlib.pyplot as plt
4
  from matplotlib.patches import Rectangle
 
 
5
  from astropy.io import fits
6
  from astropy.wcs import WCS
7
  from astropy.nddata import Cutout2D, CCDData
 
 
8
  from tensorflow.keras.models import load_model
 
9
 
 
 
10
  st.set_option('deprecation.showPyplotGlobalUse', False)
11
 
 
12
  st.title("Cavity Detection Tool")
13
 
14
  st.markdown("Cavity Detection Tool (CADET) is a machine learning pipeline trained to detect X-ray cavities from noisy Chandra images of early-type galaxies.To use this tool upload your image, select the scale of interest and make a prediction! If you use output of this tool in your research please cite [Plšek et al. 2023](https://arxiv.org/abs/2304.05457)")
15
 
16
- model = load_model("CADET.hdf5")
17
-
18
  # Create file uploader widget
19
  uploaded_file = st.file_uploader("Choose a FITS file", type=['fits'])
20
 
 
21
  col1, col2 = st.columns(2)
22
-
23
  col1.subheader("Input image")
24
- col.subheader("CADET prediction")
25
 
26
  # Define function to plot the uploaded image
27
  def plot_image(image_array, scale):
@@ -32,7 +38,6 @@ def plot_image(image_array, scale):
32
  plt.gca().add_patch(Rectangle((x0, x0), scale*128, scale*128, linewidth=1, edgecolor='w', facecolor='none'))
33
 
34
  plt.axis('off')
35
- # with col1:
36
  st.pyplot()
37
 
38
  # Define function to plot the prediction
@@ -43,6 +48,7 @@ def plot_prediction(pred):
43
  # with col2:
44
  st.pyplot()
45
 
 
46
  def cut(data0, wcs0, scale=1):
47
  shape = data0.shape[0]
48
  x0 = shape / 2
@@ -50,11 +56,11 @@ def cut(data0, wcs0, scale=1):
50
  cutout = Cutout2D(data0, (x0, x0), (size, size), wcs=wcs0)
51
  data, wcs = cutout.data, cutout.wcs
52
 
53
- # REGRID DATA
54
  factor = size // 128
55
  data = data.reshape(128, factor, 128, factor).mean(-1).mean(1)
56
 
57
- # REGIRD WCS
58
  ra, dec = wcs.wcs_pix2world(np.array([[63, 63]]),0)[0]
59
  wcs.wcs.cdelt[0] = wcs.wcs.cdelt[0] * factor
60
  wcs.wcs.cdelt[1] = wcs.wcs.cdelt[1] * factor
@@ -71,33 +77,32 @@ if uploaded_file is not None:
71
  with fits.open(uploaded_file) as hdul:
72
  data = hdul[0].data
73
  wcs = WCS(hdul[0].header)
 
 
 
 
 
 
 
 
 
 
 
 
74
 
75
- # Add a slider to change the scale
76
- with col1:
77
- max_scale = int(data.shape[0] // 128)
78
- scale = st.slider("Scale", 1, max_scale, 1, 1)
 
 
79
 
80
- plot_image(np.log10(data+1), scale)
81
 
82
- with col2:
83
- if st.button('Detect cavities'):
84
- data, wcs = cut(data, wcs, scale=scale)
85
 
86
- image_data = np.log10(data+1)
87
-
88
- y_pred = 0
89
- for j in [0,1,2,3]:
90
- rotated = np.rot90(image_data, j)
91
- pred = model.predict(rotated.reshape(1, 128, 128, 1)).reshape(128 ,128)
92
- pred = np.rot90(pred, -j)
93
- y_pred += pred / 4
94
-
95
- # ccd = CCDData(pred, unit="adu", wcs=wcs)
96
- # ccd.write(f"predicted.fits", overwrite=True)
97
-
98
- plot_prediction(y_pred)
99
-
100
- # if st.button('Download FITS File'):
101
- # with open('predicted.fits', 'rb') as f:
102
- # data = f.read()
103
- # st.download_button(label="Download", data=data, file_name="predicted.fits", mime="application/octet-stream")
 
1
+ # Basic libraries
2
  import numpy as np
3
  import matplotlib.pyplot as plt
4
  from matplotlib.patches import Rectangle
5
+
6
+ # Astropy
7
  from astropy.io import fits
8
  from astropy.wcs import WCS
9
  from astropy.nddata import Cutout2D, CCDData
10
+
11
+ # Tensorflow
12
  from tensorflow.keras.models import load_model
13
+ model = load_model("CADET.hdf5")
14
 
15
+ # Streamlit
16
+ import streamlit as st
17
  st.set_option('deprecation.showPyplotGlobalUse', False)
18
 
19
+
20
  st.title("Cavity Detection Tool")
21
 
22
  st.markdown("Cavity Detection Tool (CADET) is a machine learning pipeline trained to detect X-ray cavities from noisy Chandra images of early-type galaxies.To use this tool upload your image, select the scale of interest and make a prediction! If you use output of this tool in your research please cite [Plšek et al. 2023](https://arxiv.org/abs/2304.05457)")
23
 
 
 
24
  # Create file uploader widget
25
  uploaded_file = st.file_uploader("Choose a FITS file", type=['fits'])
26
 
27
+ # Make two columns
28
  col1, col2 = st.columns(2)
 
29
  col1.subheader("Input image")
30
+ col2.subheader("CADET prediction")
31
 
32
  # Define function to plot the uploaded image
33
  def plot_image(image_array, scale):
 
38
  plt.gca().add_patch(Rectangle((x0, x0), scale*128, scale*128, linewidth=1, edgecolor='w', facecolor='none'))
39
 
40
  plt.axis('off')
 
41
  st.pyplot()
42
 
43
  # Define function to plot the prediction
 
48
  # with col2:
49
  st.pyplot()
50
 
51
+ # Cut input image and rebin it to 128x128 pixels
52
  def cut(data0, wcs0, scale=1):
53
  shape = data0.shape[0]
54
  x0 = shape / 2
 
56
  cutout = Cutout2D(data0, (x0, x0), (size, size), wcs=wcs0)
57
  data, wcs = cutout.data, cutout.wcs
58
 
59
+ # Regrid data
60
  factor = size // 128
61
  data = data.reshape(128, factor, 128, factor).mean(-1).mean(1)
62
 
63
+ # Regrid wcs
64
  ra, dec = wcs.wcs_pix2world(np.array([[63, 63]]),0)[0]
65
  wcs.wcs.cdelt[0] = wcs.wcs.cdelt[0] * factor
66
  wcs.wcs.cdelt[1] = wcs.wcs.cdelt[1] * factor
 
77
  with fits.open(uploaded_file) as hdul:
78
  data = hdul[0].data
79
  wcs = WCS(hdul[0].header)
80
+ image = np.log10(data+1)
81
+
82
+ # Add a slider to change the scale
83
+ with col1:
84
+ max_scale = int(data.shape[0] // 128)
85
+ scale = st.slider("Scale", 1, max_scale, 1, 1)
86
+
87
+ plot_image(image, scale)
88
+
89
+ with col2:
90
+ st.button('Detect cavities')
91
+ data, wcs = cut(data, wcs, scale=scale)
92
 
93
+ y_pred = 0
94
+ for j in [0,1,2,3]:
95
+ rotated = np.rot90(image, j)
96
+ pred = model.predict(rotated.reshape(1, 128, 128, 1)).reshape(128 ,128)
97
+ pred = np.rot90(pred, -j)
98
+ y_pred += pred / 4
99
 
100
+ plot_prediction(y_pred)
101
 
102
+ # ccd = CCDData(pred, unit="adu", wcs=wcs)
103
+ # ccd.write(f"predicted.fits", overwrite=True)
 
104
 
105
+ # if st.button('Download FITS File'):
106
+ # with open('predicted.fits', 'rb') as f:
107
+ # data = f.read()
108
+ # st.download_button(label="Download", data=data, file_name="predicted.fits", mime="application/octet-stream")