Plsek commited on
Commit
778987b
·
1 Parent(s): 0e52fa4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -11
app.py CHANGED
@@ -9,8 +9,12 @@ from astropy.wcs import WCS
9
  from astropy.nddata import Cutout2D, CCDData
10
 
11
  # Tensorflow
12
- from tensorflow.keras.models import load_model
13
- model = load_model("CADET.hdf5")
 
 
 
 
14
 
15
  # Streamlit
16
  import streamlit as st
@@ -75,15 +79,16 @@ if uploaded_file is not None:
75
  # Make two columns
76
  col1, col2, col3, col4 = st.columns(4)
77
  col1.subheader("Input image")
78
- col3.subheader("CADET prediction")
79
 
80
  # Add a slider to change the scale
81
  with col1:
82
  smooth = st.button("Smooth")
83
 
84
  with col2:
85
- st.markdown("""<style>[data-baseweb="select"] {margin-top: -52px;}</style>""", unsafe_allow_html=True)
86
-
 
87
  max_scale = int(data.shape[0] // 128)
88
  # scale = st.slider("Scale", 1, max_scale, 1, 1)
89
  scale = int(st.selectbox('Scale:',[i+1 for i in range(max_scale)], label_visibility="hidden"))
@@ -92,6 +97,7 @@ if uploaded_file is not None:
92
  detect = st.button('Detect cavities')
93
 
94
  with col4:
 
95
  download = st.button('Download')
96
 
97
 
@@ -117,9 +123,9 @@ if uploaded_file is not None:
117
 
118
  plot_prediction(y_pred)
119
 
120
- if download:
121
- ccd = CCDData(y_pred, unit="adu", wcs=wcs)
122
- ccd.write("predicted.fits", overwrite=True)
123
- with open('predicted.fits', 'rb') as f:
124
- data = f.read()
125
- st.download_button(label="Download", data=data, file_name="predicted.fits", mime="application/octet-stream")
 
9
  from astropy.nddata import Cutout2D, CCDData
10
 
11
  # Tensorflow
12
+ # from tensorflow.keras.models import load_model
13
+ # model = load_model("CADET.hdf5")
14
+
15
+ # HuggingFace
16
+ from huggingface_hub import from_pretrained_keras
17
+ model = from_pretrained_keras("Plsek/CADET-v1")
18
 
19
  # Streamlit
20
  import streamlit as st
 
79
  # Make two columns
80
  col1, col2, col3, col4 = st.columns(4)
81
  col1.subheader("Input image")
82
+ col3.subheader("Prediction")
83
 
84
  # Add a slider to change the scale
85
  with col1:
86
  smooth = st.button("Smooth")
87
 
88
  with col2:
89
+ # st.markdown("""<style>[data-baseweb="select"] {margin-top: -52px;}</style>""", unsafe_allow_html=True)
90
+ st.markdown("""<style>[data-baseweb="select"] {margin-top: 52px;}</style>""", unsafe_allow_html=True)
91
+
92
  max_scale = int(data.shape[0] // 128)
93
  # scale = st.slider("Scale", 1, max_scale, 1, 1)
94
  scale = int(st.selectbox('Scale:',[i+1 for i in range(max_scale)], label_visibility="hidden"))
 
97
  detect = st.button('Detect cavities')
98
 
99
  with col4:
100
+ st.markdown("""<style>[data-baseweb="select"] {margin-top: 52px;}</style>""", unsafe_allow_html=True)
101
  download = st.button('Download')
102
 
103
 
 
123
 
124
  plot_prediction(y_pred)
125
 
126
+ if download:
127
+ ccd = CCDData(y_pred, unit="adu", wcs=wcs)
128
+ ccd.write("predicted.fits", overwrite=True)
129
+ with open('predicted.fits', 'rb') as f:
130
+ data = f.read()
131
+ st.download_button(label="Download", data=data, file_name="predicted.fits", mime="application/octet-stream")