Plsek commited on
Commit
3190548
·
1 Parent(s): 44cb863

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -8
app.py CHANGED
@@ -3,7 +3,7 @@ import numpy as np
3
  import matplotlib.pyplot as plt
4
  from astropy.io import fits
5
  from astropy.wcs import WCS
6
- from astropy.nddata import Cutout2D
7
  from tensorflow.keras.models import load_model
8
 
9
  st.set_option('deprecation.showPyplotGlobalUse', False)
@@ -13,7 +13,14 @@ st.title("Cavity Detection Tool")
13
  model = load_model("CADET.hdf5")
14
 
15
  # Define function to plot the uploaded image
16
- def plot_image(image_array, pred):
 
 
 
 
 
 
 
17
  plt.figure(figsize=(10, 5))
18
  plt.subplot(1, 2, 1)
19
  plt.imshow(image_array, origin="lower")
@@ -57,12 +64,17 @@ if uploaded_file is not None:
57
  with fits.open(uploaded_file) as hdul:
58
  data = hdul[0].data
59
  wcs = WCS(hdul[0].header)
60
- data, wcs = cut(data, wcs, scale=scale)
61
 
62
- image_data = np.log10(data+1)
63
- pred = model.predict(image_data.reshape(1, 128, 128, 1)).reshape(128 ,128)
 
 
 
 
 
 
64
 
65
- ccd = CCDData(pred, unit="adu", wcs=wcs)
66
- ccd.write(f"predicted.fits", overwrite=True)
67
 
68
- plot_image(image_data, pred)
 
3
  import matplotlib.pyplot as plt
4
  from astropy.io import fits
5
  from astropy.wcs import WCS
6
+ from astropy.nddata import Cutout2D, CCDData
7
  from tensorflow.keras.models import load_model
8
 
9
  st.set_option('deprecation.showPyplotGlobalUse', False)
 
13
  model = load_model("CADET.hdf5")
14
 
15
  # Define function to plot the uploaded image
16
+ def plot_image(image_array, scale):
17
+ plt.figure(figsize=(5, 5))
18
+ # plt.subplot(1, 2, 1)
19
+ plt.imshow(image_array, origin="lower")
20
+ plt.axis('off')
21
+
22
+ # Define function to plot the prediction
23
+ def plot_prediction(image_array, pred):
24
  plt.figure(figsize=(10, 5))
25
  plt.subplot(1, 2, 1)
26
  plt.imshow(image_array, origin="lower")
 
64
  with fits.open(uploaded_file) as hdul:
65
  data = hdul[0].data
66
  wcs = WCS(hdul[0].header)
 
67
 
68
+ plot_image(np.log10(data+1), scale)
69
+
70
+ if st.button('Detect Cavity'):
71
+ data, wcs = cut(data, wcs, scale=scale)
72
+
73
+ image_data = np.log10(data+1)
74
+
75
+ pred = model.predict(image_data.reshape(1, 128, 128, 1)).reshape(128 ,128)
76
 
77
+ # ccd = CCDData(pred, unit="adu", wcs=wcs)
78
+ # ccd.write(f"predicted.fits", overwrite=True)
79
 
80
+ plot_prediction(image_data, pred)