Plsek commited on
Commit
b635f79
·
1 Parent(s): ec61752

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -32
app.py CHANGED
@@ -24,9 +24,6 @@ st.markdown("Cavity Detection Tool (CADET) is a machine learning pipeline traine
24
  # Create file uploader widget
25
  uploaded_file = st.file_uploader("Choose a FITS file", type=['fits'])
26
 
27
- # Make two columns
28
- col1, col2 = st.columns(2)
29
-
30
  # Define function to plot the uploaded image
31
  def plot_image(image_array, scale):
32
  plt.figure(figsize=(4, 4))
@@ -69,45 +66,53 @@ def cut(data0, wcs0, scale=1):
69
 
70
  return data, wcs
71
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
  # If file is uploaded, read in the data and plot it
73
  if uploaded_file is not None:
74
- col1.subheader("Input image")
75
- col2.subheader("CADET prediction")
76
-
77
  with fits.open(uploaded_file) as hdul:
78
  data = hdul[0].data
79
  wcs = WCS(hdul[0].header)
80
 
81
  # Add a slider to change the scale
82
  with col1:
83
- max_scale = int(data.shape[0] // 128)
84
- # scale = st.slider("Scale", 1, max_scale, 1, 1)
85
-
86
- st.markdown(
87
- """<style>[data-baseweb="select"] {margin-top: -50px;}</style>""",
88
- unsafe_allow_html=True
89
- )
90
-
91
- scale = int(st.selectbox('Scale:',[i+1 for i in range(max_scale)], label_visibility="hidden"))
92
-
93
  plot_image(np.log10(data+1), scale)
94
 
95
- with col2:
96
- if st.button('Detect cavities', disabled=True):
97
- data, wcs = cut(data, wcs, scale=scale)
98
-
99
- image = np.log10(data+1)
100
-
101
- y_pred = 0
102
- for j in [0,1,2,3]:
103
- rotated = np.rot90(image, j)
104
- pred = model.predict(rotated.reshape(1, 128, 128, 1)).reshape(128 ,128)
105
- pred = np.rot90(pred, -j)
106
- y_pred += pred / 4
107
-
108
- # Thresholding
109
- y_pred = np.where(y_pred > 0.4, y_pred, 0)
110
-
 
 
111
  plot_prediction(y_pred)
112
 
113
  # ccd = CCDData(pred, unit="adu", wcs=wcs)
 
24
  # Create file uploader widget
25
  uploaded_file = st.file_uploader("Choose a FITS file", type=['fits'])
26
 
 
 
 
27
  # Define function to plot the uploaded image
28
  def plot_image(image_array, scale):
29
  plt.figure(figsize=(4, 4))
 
66
 
67
  return data, wcs
68
 
69
+ # Make two columns
70
+ col1, col2 = st.columns(2)
71
+ col1.subheader("Input image")
72
+ col2.subheader("CADET prediction")
73
+
74
+ with col1:
75
+ st.markdown(
76
+ """<style>[data-baseweb="select"] {margin-top: -50px;}</style>""",
77
+ unsafe_allow_html=True
78
+ )
79
+
80
+ max_scale = int(data.shape[0] // 128)
81
+ # scale = st.slider("Scale", 1, max_scale, 1, 1)
82
+ scale = int(st.selectbox('Scale:',[i+1 for i in range(max_scale)], label_visibility="hidden"))
83
+
84
+ with col2:
85
+ button == st.button('Detect cavities', disabled=True)
86
+
87
+
88
  # If file is uploaded, read in the data and plot it
89
  if uploaded_file is not None:
 
 
 
90
  with fits.open(uploaded_file) as hdul:
91
  data = hdul[0].data
92
  wcs = WCS(hdul[0].header)
93
 
94
  # Add a slider to change the scale
95
  with col1:
 
 
 
 
 
 
 
 
 
 
96
  plot_image(np.log10(data+1), scale)
97
 
98
+ button.disabled = False
99
+
100
+ if button:
101
+ data, wcs = cut(data, wcs, scale=scale)
102
+
103
+ image = np.log10(data+1)
104
+
105
+ y_pred = 0
106
+ for j in [0,1,2,3]:
107
+ rotated = np.rot90(image, j)
108
+ pred = model.predict(rotated.reshape(1, 128, 128, 1)).reshape(128 ,128)
109
+ pred = np.rot90(pred, -j)
110
+ y_pred += pred / 4
111
+
112
+ # Thresholding
113
+ y_pred = np.where(y_pred > 0.4, y_pred, 0)
114
+
115
+ with col2:
116
  plot_prediction(y_pred)
117
 
118
  # ccd = CCDData(pred, unit="adu", wcs=wcs)