Update app.py
Browse files
app.py
CHANGED
@@ -9,8 +9,12 @@ from astropy.wcs import WCS
|
|
9 |
from astropy.nddata import Cutout2D, CCDData
|
10 |
|
11 |
# Tensorflow
|
12 |
-
from tensorflow.keras.models import load_model
|
13 |
-
model = load_model("CADET.hdf5")
|
|
|
|
|
|
|
|
|
14 |
|
15 |
# Streamlit
|
16 |
import streamlit as st
|
@@ -75,15 +79,16 @@ if uploaded_file is not None:
|
|
75 |
# Make two columns
|
76 |
col1, col2, col3, col4 = st.columns(4)
|
77 |
col1.subheader("Input image")
|
78 |
-
col3.subheader("
|
79 |
|
80 |
# Add a slider to change the scale
|
81 |
with col1:
|
82 |
smooth = st.button("Smooth")
|
83 |
|
84 |
with col2:
|
85 |
-
st.markdown("""<style>[data-baseweb="select"] {margin-top: -52px;}</style>""", unsafe_allow_html=True)
|
86 |
-
|
|
|
87 |
max_scale = int(data.shape[0] // 128)
|
88 |
# scale = st.slider("Scale", 1, max_scale, 1, 1)
|
89 |
scale = int(st.selectbox('Scale:',[i+1 for i in range(max_scale)], label_visibility="hidden"))
|
@@ -92,6 +97,7 @@ if uploaded_file is not None:
|
|
92 |
detect = st.button('Detect cavities')
|
93 |
|
94 |
with col4:
|
|
|
95 |
download = st.button('Download')
|
96 |
|
97 |
|
@@ -117,9 +123,9 @@ if uploaded_file is not None:
|
|
117 |
|
118 |
plot_prediction(y_pred)
|
119 |
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
|
|
9 |
from astropy.nddata import Cutout2D, CCDData
|
10 |
|
11 |
# Tensorflow
|
12 |
+
# from tensorflow.keras.models import load_model
|
13 |
+
# model = load_model("CADET.hdf5")
|
14 |
+
|
15 |
+
# HuggingFace
|
16 |
+
from huggingface_hub import from_pretrained_keras
|
17 |
+
model = from_pretrained_keras("Plsek/CADET-v1")
|
18 |
|
19 |
# Streamlit
|
20 |
import streamlit as st
|
|
|
79 |
# Make two columns
|
80 |
col1, col2, col3, col4 = st.columns(4)
|
81 |
col1.subheader("Input image")
|
82 |
+
col3.subheader("Prediction")
|
83 |
|
84 |
# Add a slider to change the scale
|
85 |
with col1:
|
86 |
smooth = st.button("Smooth")
|
87 |
|
88 |
with col2:
|
89 |
+
# st.markdown("""<style>[data-baseweb="select"] {margin-top: -52px;}</style>""", unsafe_allow_html=True)
|
90 |
+
st.markdown("""<style>[data-baseweb="select"] {margin-top: 52px;}</style>""", unsafe_allow_html=True)
|
91 |
+
|
92 |
max_scale = int(data.shape[0] // 128)
|
93 |
# scale = st.slider("Scale", 1, max_scale, 1, 1)
|
94 |
scale = int(st.selectbox('Scale:',[i+1 for i in range(max_scale)], label_visibility="hidden"))
|
|
|
97 |
detect = st.button('Detect cavities')
|
98 |
|
99 |
with col4:
|
100 |
+
st.markdown("""<style>[data-baseweb="select"] {margin-top: 52px;}</style>""", unsafe_allow_html=True)
|
101 |
download = st.button('Download')
|
102 |
|
103 |
|
|
|
123 |
|
124 |
plot_prediction(y_pred)
|
125 |
|
126 |
+
if download:
|
127 |
+
ccd = CCDData(y_pred, unit="adu", wcs=wcs)
|
128 |
+
ccd.write("predicted.fits", overwrite=True)
|
129 |
+
with open('predicted.fits', 'rb') as f:
|
130 |
+
data = f.read()
|
131 |
+
st.download_button(label="Download", data=data, file_name="predicted.fits", mime="application/octet-stream")
|