Plsek commited on
Commit
15c5050
·
1 Parent(s): f27ae54

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +50 -20
app.py CHANGED
@@ -18,12 +18,20 @@ model = from_pretrained_keras("Plsek/CADET-v1")
18
  import streamlit as st
19
  st.set_option('deprecation.showPyplotGlobalUse', False)
20
 
 
 
 
 
 
 
 
21
 
22
- st.title("Cavity Detection Tool")
 
23
 
24
  st.markdown("Cavity Detection Tool (CADET) is a machine learning pipeline trained to detect X-ray cavities from noisy Chandra images of early-type galaxies. To use this tool: upload your image, select the scale of interest, and make a prediction! If you use this tool for your research, please cite [Plšek et al. 2023](https://arxiv.org/abs/2304.05457)")
25
 
26
- st.markdown("Input images should be centered at the centre of the galaxy and point sources should be filled with surrounding background [dmfilth](https://cxc.cfa.harvard.edu/ciao/ahelp/dmfilth.html).")
27
 
28
 
29
  # Create file uploader widget
@@ -40,13 +48,6 @@ def plot_image(image, scale):
40
  plt.axis('off')
41
  with colA: st.pyplot()
42
 
43
- # Define function to smooth image
44
- def smooth_image(image, scale):
45
- smoothed = convolve(image, boundary = "wrap", nan_treatment="interpolate",
46
- kernel = Gauss(x_stddev = 2, y_stddev = 2))
47
-
48
- return smoothed
49
-
50
  # Define function to plot the prediction
51
  def plot_prediction(pred):
52
  plt.figure(figsize=(4, 4))
@@ -77,6 +78,32 @@ def cut(data0, wcs0, scale=1):
77
 
78
  return data, wcs
79
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
 
81
  # If file is uploaded, read in the data and plot it
82
  if uploaded_file is not None:
@@ -88,24 +115,24 @@ if uploaded_file is not None:
88
  col1, col2, col3, col4 = st.columns(4)
89
  col1.subheader("Input image")
90
  col3.subheader("Prediction")
91
-
92
- # Add a slider to change the scale
93
- with col1:
94
- smooth = st.button("Smooth")
95
 
96
- with col2:
97
- st.markdown("""<style>[data-baseweb="select"] {margin-top: 17px;}</style>""", unsafe_allow_html=True)
98
  max_scale = int(data.shape[0] // 128)
99
- scale = int(st.selectbox('Scale:',[i+1 for i in range(max_scale)], label_visibility="hidden"))
 
 
100
 
101
- with col3:
102
  detect = st.button('Detect cavities')
103
 
 
 
 
104
  # Make two columns for plots
105
  colA, colB = st.columns(2)
106
 
107
  image = np.log10(data+1)
108
- if smooth: image = smooth_image(image, scale)
109
  plot_image(image, scale)
110
 
111
  if detect:
@@ -122,7 +149,10 @@ if uploaded_file is not None:
122
  # Thresholding
123
  y_pred = np.where(y_pred > 0.4, y_pred, 0)
124
 
125
- plot_prediction(y_pred)
 
 
 
126
 
127
  ccd = CCDData(y_pred, unit="adu", wcs=wcs)
128
  ccd.write("predicted.fits", overwrite=True)
@@ -131,6 +161,6 @@ if uploaded_file is not None:
131
 
132
  with col4:
133
  pass
134
- # st.markdown("""<style>[data-baseweb="select"] {margin-top: 32px;}</style>""", unsafe_allow_html=True)
135
  # # download = st.button('Download')
136
  download = st.download_button(label="Download", data=res, file_name="predicted.fits", mime="application/octet-stream")
 
18
  import streamlit as st
19
  st.set_option('deprecation.showPyplotGlobalUse', False)
20
 
21
+ st.set_page_config(
22
+ page_title="Cavity Detection Tool",
23
+ # page_icon="👋",
24
+ layout="wide"
25
+ # initial_sidebar_state="expanded",
26
+ }
27
+ )
28
 
29
+
30
+ # st.title("Cavity Detection Tool")
31
 
32
  st.markdown("Cavity Detection Tool (CADET) is a machine learning pipeline trained to detect X-ray cavities from noisy Chandra images of early-type galaxies. To use this tool: upload your image, select the scale of interest, and make a prediction! If you use this tool for your research, please cite [Plšek et al. 2023](https://arxiv.org/abs/2304.05457)")
33
 
34
+ st.markdown("Input images should be centered at the centre of the galaxy and point sources should be filled with surrounding background ([dmfilth](https://cxc.cfa.harvard.edu/ciao/ahelp/dmfilth.html)).")
35
 
36
 
37
  # Create file uploader widget
 
48
  plt.axis('off')
49
  with colA: st.pyplot()
50
 
 
 
 
 
 
 
 
51
  # Define function to plot the prediction
52
  def plot_prediction(pred):
53
  plt.figure(figsize=(4, 4))
 
78
 
79
  return data, wcs
80
 
81
+ def decompose_cavity(pred, th2=0.7, amin=10):
82
+ X, Y = pred.nonzero()
83
+ data = np.array([X,Y]).reshape(2, -1)
84
+
85
+ # DBSCAN CLUSTERING ALGORITHM
86
+ try: clusters = DBSCAN(eps=1.5, min_samples=3).fit(data.T).labels_
87
+ except: clusters = []
88
+
89
+ N = len(set(clusters))
90
+ cavities = []
91
+
92
+ for i in range(N):
93
+ img = np.zeros((128,128))
94
+ b = clusters == i
95
+ xi, yi = X[b], Y[b]
96
+ img[xi, yi] = pred[xi, yi]
97
+
98
+ # THRESHOLDING #2
99
+ if not (img > th2).any(): continue
100
+
101
+ # MINIMAL AREA
102
+ if np.sum(img) <= amin: continue
103
+
104
+ cavities.append(img)
105
+
106
+ return cavities
107
 
108
  # If file is uploaded, read in the data and plot it
109
  if uploaded_file is not None:
 
115
  col1, col2, col3, col4 = st.columns(4)
116
  col1.subheader("Input image")
117
  col3.subheader("Prediction")
 
 
 
 
118
 
119
+ with col1:
120
+ st.markdown("""<style>[data-baseweb="select"] {margin-top: 16px;}</style>""", unsafe_allow_html=True)
121
  max_scale = int(data.shape[0] // 128)
122
+ # scale = int(st.selectbox('Scale:',[i+1 for i in range(max_scale)], label_visibility="hidden"))
123
+ scale = int(st.selectbox('Scale:',[f"{(i+1)*128}x{(i+1)*128}" for i in range(max_scale)], label_visibility="hidden"))
124
+ scale = scale.split("x")[0] // 128
125
 
126
+ with col2:
127
  detect = st.button('Detect cavities')
128
 
129
+ with col3:
130
+ decompose = st.button('Docompose cavities')
131
+
132
  # Make two columns for plots
133
  colA, colB = st.columns(2)
134
 
135
  image = np.log10(data+1)
 
136
  plot_image(image, scale)
137
 
138
  if detect:
 
149
  # Thresholding
150
  y_pred = np.where(y_pred > 0.4, y_pred, 0)
151
 
152
+ # if decompose:
153
+ # cavs = decompose_cavity(y_pred, )
154
+
155
+ plot_prediction(y_pred, decompose)
156
 
157
  ccd = CCDData(y_pred, unit="adu", wcs=wcs)
158
  ccd.write("predicted.fits", overwrite=True)
 
161
 
162
  with col4:
163
  pass
164
+ st.markdown("""<style>[data-baseweb="select"] {margin-top: 16px;}</style>""", unsafe_allow_html=True)
165
  # # download = st.button('Download')
166
  download = st.download_button(label="Download", data=res, file_name="predicted.fits", mime="application/octet-stream")