Plsek commited on
Commit
29cd1e9
·
1 Parent(s): 5332a80

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -5
app.py CHANGED
@@ -56,6 +56,13 @@ def plot_prediction(pred):
56
  plt.axis('off')
57
  with colB: st.pyplot()
58
 
 
 
 
 
 
 
 
59
  # Cut input image and rebin it to 128x128 pixels
60
  def cut(data0, wcs0, scale=1):
61
  shape = data0.shape[0]
@@ -105,8 +112,6 @@ def decompose_cavity(pred, th2=0.7, amin=10):
105
  cavities.append(img)
106
 
107
  return cavities
108
-
109
- np.save("pred.npy", np.zeros((128,128)))
110
 
111
  # If file is uploaded, read in the data and plot it
112
  if uploaded_file is not None:
@@ -157,9 +162,9 @@ if uploaded_file is not None:
157
 
158
  np.save("pred.npy", y_pred)
159
 
160
- y_pred = np.load("pred.npy")
 
161
  y_pred = np.where(y_pred > threshold, y_pred, 0)
162
-
163
  np.save("thresh.npy", y_pred)
164
 
165
  plot_prediction(y_pred)
@@ -177,7 +182,9 @@ if uploaded_file is not None:
177
  for i, cav in enumerate(cavs):
178
  ccd = CCDData(cav, unit="adu", wcs=wcs)
179
  ccd.write(f"predicted_{i+1}.fits", overwrite=True)
180
-
 
 
181
  # with col4:
182
  # pass
183
  # st.markdown("""<style>[data-baseweb="select"] {margin-top: 16px;}</style>""", unsafe_allow_html=True)
 
56
  plt.axis('off')
57
  with colB: st.pyplot()
58
 
59
+ # Define function to plot the decomposed prediction
60
+ def plot_prediction(pred):
61
+ plt.figure(figsize=(4, 4))
62
+ plt.imshow(pred, origin="lower")
63
+ plt.axis('off')
64
+ with colC: st.pyplot()
65
+
66
  # Cut input image and rebin it to 128x128 pixels
67
  def cut(data0, wcs0, scale=1):
68
  shape = data0.shape[0]
 
112
  cavities.append(img)
113
 
114
  return cavities
 
 
115
 
116
  # If file is uploaded, read in the data and plot it
117
  if uploaded_file is not None:
 
162
 
163
  np.save("pred.npy", y_pred)
164
 
165
+ try: y_pred = np.load("pred.npy")
166
+ except: y_pred = np.zeros((128,128))
167
  y_pred = np.where(y_pred > threshold, y_pred, 0)
 
168
  np.save("thresh.npy", y_pred)
169
 
170
  plot_prediction(y_pred)
 
182
  for i, cav in enumerate(cavs):
183
  ccd = CCDData(cav, unit="adu", wcs=wcs)
184
  ccd.write(f"predicted_{i+1}.fits", overwrite=True)
185
+
186
+ plot_decomposed(np.zeros((128,128)))
187
+
188
  # with col4:
189
  # pass
190
  # st.markdown("""<style>[data-baseweb="select"] {margin-top: 16px;}</style>""", unsafe_allow_html=True)