Plsek commited on
Commit
d01a85f
·
1 Parent(s): 31764f2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -10
app.py CHANGED
@@ -7,6 +7,8 @@ from matplotlib.patches import Rectangle
7
  from astropy.io import fits
8
  from astropy.wcs import WCS
9
  from astropy.nddata import Cutout2D, CCDData
 
 
10
 
11
  # HuggingFace
12
  from huggingface_hub import from_pretrained_keras
@@ -25,16 +27,23 @@ st.markdown("Cavity Detection Tool (CADET) is a machine learning pipeline traine
25
  uploaded_file = st.file_uploader("Choose a FITS file", type=['fits'])
26
 
27
  # Define function to plot the uploaded image
28
- def plot_image(image_array, scale):
29
  plt.figure(figsize=(4, 4))
30
 
31
- x0 = image_array.shape[0] // 2 - scale * 128 / 2
32
- plt.imshow(image_array, origin="lower")
33
  plt.gca().add_patch(Rectangle((x0, x0), scale*128, scale*128, linewidth=1, edgecolor='w', facecolor='none'))
34
 
35
  plt.axis('off')
36
  with colA: st.pyplot()
37
 
 
 
 
 
 
 
 
38
  # Define function to plot the prediction
39
  def plot_prediction(pred):
40
  plt.figure(figsize=(4, 4))
@@ -83,7 +92,7 @@ if uploaded_file is not None:
83
 
84
  with col2:
85
  # st.markdown("""<style>[data-baseweb="select"] {margin-top: -52px;}</style>""", unsafe_allow_html=True)
86
- st.markdown("""<style>[data-baseweb="select"] {margin-top: 32px;}</style>""", unsafe_allow_html=True)
87
 
88
  max_scale = int(data.shape[0] // 128)
89
  # scale = st.slider("Scale", 1, max_scale, 1, 1)
@@ -94,12 +103,12 @@ if uploaded_file is not None:
94
 
95
  colA, colB = st.columns(2)
96
 
97
- plot_image(np.log10(data+1), scale)
98
-
 
99
 
100
  if detect:
101
  data, wcs = cut(data, wcs, scale=scale)
102
-
103
  image = np.log10(data+1)
104
 
105
  y_pred = 0
@@ -112,14 +121,14 @@ if uploaded_file is not None:
112
  # Thresholding
113
  y_pred = np.where(y_pred > 0.4, y_pred, 0)
114
 
115
- # plot_prediction(y_pred)
116
 
117
  ccd = CCDData(y_pred, unit="adu", wcs=wcs)
118
  ccd.write("predicted.fits", overwrite=True)
119
  with open('predicted.fits', 'rb') as f:
120
  data = f.read()
121
 
122
- # with col4:
123
- # st.markdown("""<style>[data-baseweb="select"] {margin-top: 32px;}</style>""", unsafe_allow_html=True)
124
  # # download = st.button('Download')
125
  # download = st.download_button(label="Download", data=data, file_name="predicted.fits", mime="application/octet-stream")
 
7
  from astropy.io import fits
8
  from astropy.wcs import WCS
9
  from astropy.nddata import Cutout2D, CCDData
10
+ from astropy.convolution import Gaussian2DKernel as Gauss
11
+ from astropy.convolution import convolve
12
 
13
  # HuggingFace
14
  from huggingface_hub import from_pretrained_keras
 
27
  uploaded_file = st.file_uploader("Choose a FITS file", type=['fits'])
28
 
29
  # Define function to plot the uploaded image
30
+ def plot_image(image, scale):
31
  plt.figure(figsize=(4, 4))
32
 
33
+ x0 = image.shape[0] // 2 - scale * 128 / 2
34
+ plt.imshow(image, origin="lower")
35
  plt.gca().add_patch(Rectangle((x0, x0), scale*128, scale*128, linewidth=1, edgecolor='w', facecolor='none'))
36
 
37
  plt.axis('off')
38
  with colA: st.pyplot()
39
 
40
+ # Define function to smooth image
41
+ def smooth_image(image, scale):
42
+ smoothed = convolve(image, boundary = "wrap", nan_treatment="interpolate",
43
+ kernel = Gauss(x_stddev = 2, y_stddev = 2))
44
+
45
+ return smoothed
46
+
47
  # Define function to plot the prediction
48
  def plot_prediction(pred):
49
  plt.figure(figsize=(4, 4))
 
92
 
93
  with col2:
94
  # st.markdown("""<style>[data-baseweb="select"] {margin-top: -52px;}</style>""", unsafe_allow_html=True)
95
+ st.markdown("""<style>[data-baseweb="select"] {margin-top: 22px;}</style>""", unsafe_allow_html=True)
96
 
97
  max_scale = int(data.shape[0] // 128)
98
  # scale = st.slider("Scale", 1, max_scale, 1, 1)
 
103
 
104
  colA, colB = st.columns(2)
105
 
106
+ image = np.log10(data+1)
107
+ if smooth: image = smooth_image(image, scale)
108
+ plot_image(image, scale)
109
 
110
  if detect:
111
  data, wcs = cut(data, wcs, scale=scale)
 
112
  image = np.log10(data+1)
113
 
114
  y_pred = 0
 
121
  # Thresholding
122
  y_pred = np.where(y_pred > 0.4, y_pred, 0)
123
 
124
+ plot_prediction(y_pred)
125
 
126
  ccd = CCDData(y_pred, unit="adu", wcs=wcs)
127
  ccd.write("predicted.fits", overwrite=True)
128
  with open('predicted.fits', 'rb') as f:
129
  data = f.read()
130
 
131
+ with col4:
132
+ st.markdown("""<style>[data-baseweb="select"] {margin-top: 32px;}</style>""", unsafe_allow_html=True)
133
  # # download = st.button('Download')
134
  # download = st.download_button(label="Download", data=data, file_name="predicted.fits", mime="application/octet-stream")