Update app.py
Browse files
    	
        app.py
    CHANGED
    
    | @@ -184,65 +184,65 @@ with col: | |
| 184 | 
             
            if uploaded_file is not None:
         | 
| 185 | 
             
                data, wcs = load_file(uploaded_file)
         | 
| 186 |  | 
| 187 | 
            -
            if "data" not in locals():
         | 
| 188 | 
            -
                data = np.zeros((128,128))
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 189 |  | 
| 190 | 
            -
             | 
| 191 | 
            -
             | 
| 192 | 
            -
             | 
| 193 | 
            -
             | 
| 194 | 
            -
             | 
| 195 | 
            -
            col6.subheader("")
         | 
| 196 | 
            -
             | 
| 197 | 
            -
            with col1:
         | 
| 198 | 
            -
                st.markdown("""<style>[data-baseweb="select"] {margin-top: -46px;}</style>""", unsafe_allow_html=True)
         | 
| 199 | 
            -
                max_scale = int(data.shape[0] // 128)
         | 
| 200 | 
            -
                scale = st.selectbox('Scale:',[f"{(i+1)*128}x{(i+1)*128}" for i in range(max_scale)], label_visibility="hidden", on_change=change_scale)
         | 
| 201 | 
            -
                scale = int(scale.split("x")[0]) // 128
         | 
| 202 | 
            -
             | 
| 203 | 
            -
            # Detect button
         | 
| 204 | 
            -
            with col3: detect = st.button('Detect', key="detect")
         | 
| 205 | 
            -
             | 
| 206 | 
            -
            # Threshold slider
         | 
| 207 | 
            -
            with col4:
         | 
| 208 | 
            -
                st.markdown("")
         | 
| 209 | 
            -
                # st.markdown("""<style>[data-baseweb="select"] {margin-top: -36px;}</style>""", unsafe_allow_html=True)
         | 
| 210 | 
            -
                threshold = st.slider("Threshold", 0.0, 1.0, 0.0, 0.05, key="threshold") #, label_visibility="hidden")
         | 
| 211 |  | 
| 212 | 
            -
            #  | 
| 213 | 
            -
            with  | 
| 214 |  | 
| 215 | 
            -
            #  | 
| 216 | 
            -
             | 
| 217 | 
            -
             | 
| 218 | 
            -
             | 
| 219 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 220 |  | 
| 221 | 
            -
            if detect or threshold:
         | 
| 222 | 
            -
            # if st.session_state.get("detect", True):
         | 
| 223 | 
            -
             | 
| 224 |  | 
| 225 | 
            -
             | 
| 226 |  | 
| 227 | 
            -
             | 
| 228 |  | 
| 229 | 
            -
             | 
| 230 | 
            -
             | 
| 231 |  | 
| 232 | 
            -
             | 
| 233 |  | 
| 234 | 
            -
             | 
| 235 | 
            -
             | 
| 236 | 
            -
             | 
| 237 | 
            -
             | 
| 238 |  | 
| 239 | 
            -
             | 
| 240 |  | 
| 241 | 
            -
             | 
| 242 | 
            -
             | 
| 243 | 
            -
             | 
| 244 |  | 
| 245 | 
            -
             | 
| 246 | 
            -
             | 
| 247 | 
            -
             | 
| 248 | 
            -
             | 
|  | |
| 184 | 
             
            if uploaded_file is not None:
         | 
| 185 | 
             
                data, wcs = load_file(uploaded_file)
         | 
| 186 |  | 
| 187 | 
            +
                # if "data" not in locals():
         | 
| 188 | 
            +
                #     data = np.zeros((128,128))
         | 
| 189 | 
            +
                    
         | 
| 190 | 
            +
                # Make six columns for buttons
         | 
| 191 | 
            +
                _, col1, col2, col3, col4, col5, col6, _ = st.columns([bordersize,0.5,0.5,0.5,0.5,0.5,0.5,bordersize])
         | 
| 192 | 
            +
                col1.subheader("Input image")
         | 
| 193 | 
            +
                col3.subheader("Prediction")
         | 
| 194 | 
            +
                col5.subheader("Decomposed")
         | 
| 195 | 
            +
                col6.subheader("")
         | 
| 196 |  | 
| 197 | 
            +
                with col1:
         | 
| 198 | 
            +
                    st.markdown("""<style>[data-baseweb="select"] {margin-top: -46px;}</style>""", unsafe_allow_html=True)
         | 
| 199 | 
            +
                    max_scale = int(data.shape[0] // 128)
         | 
| 200 | 
            +
                    scale = st.selectbox('Scale:',[f"{(i+1)*128}x{(i+1)*128}" for i in range(max_scale)], label_visibility="hidden", on_change=change_scale)
         | 
| 201 | 
            +
                    scale = int(scale.split("x")[0]) // 128
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 202 |  | 
| 203 | 
            +
                # Detect button
         | 
| 204 | 
            +
                with col3: detect = st.button('Detect', key="detect")
         | 
| 205 |  | 
| 206 | 
            +
                # Threshold slider
         | 
| 207 | 
            +
                with col4:
         | 
| 208 | 
            +
                    st.markdown("")
         | 
| 209 | 
            +
                    # st.markdown("""<style>[data-baseweb="select"] {margin-top: -36px;}</style>""", unsafe_allow_html=True)
         | 
| 210 | 
            +
                    threshold = st.slider("Threshold", 0.0, 1.0, 0.0, 0.05, key="threshold") #, label_visibility="hidden")
         | 
| 211 | 
            +
                    
         | 
| 212 | 
            +
                # Decompose button
         | 
| 213 | 
            +
                with col5: decompose = st.button('Decompose', key="decompose")
         | 
| 214 | 
            +
                    
         | 
| 215 | 
            +
                # Make two columns for plots
         | 
| 216 | 
            +
                _, colA, colB, colC, _ = st.columns([bordersize,1,1,1,bordersize])
         | 
| 217 | 
            +
                
         | 
| 218 | 
            +
                image = np.log10(data+1)
         | 
| 219 | 
            +
                plot_image(image, scale)
         | 
| 220 |  | 
| 221 | 
            +
            # if detect or threshold:
         | 
| 222 | 
            +
            # # if st.session_state.get("detect", True):
         | 
| 223 | 
            +
            #     y_pred, wcs = cut_n_predict(data, wcs, scale)
         | 
| 224 |  | 
| 225 | 
            +
            #     y_pred_th = np.where(y_pred > threshold, y_pred, 0)
         | 
| 226 |  | 
| 227 | 
            +
            #     plot_prediction(y_pred_th)
         | 
| 228 |  | 
| 229 | 
            +
            #     if decompose or st.session_state.get("download", False):            
         | 
| 230 | 
            +
            #         image_decomposed = decompose_cavity(y_pred_th)
         | 
| 231 |  | 
| 232 | 
            +
            #         plot_decomposed(image_decomposed)
         | 
| 233 |  | 
| 234 | 
            +
            #         with col6:
         | 
| 235 | 
            +
            #             st.markdown("<br style='margin:4px 0'>", unsafe_allow_html=True)
         | 
| 236 | 
            +
            #             # st.markdown("""<style>[data-baseweb="select"] {margin-top: 16px;}</style>""", unsafe_allow_html=True)
         | 
| 237 | 
            +
            #             fname = uploaded_file.name.strip(".fits")
         | 
| 238 |  | 
| 239 | 
            +
            #             # if st.session_state.get("download", False):
         | 
| 240 |  | 
| 241 | 
            +
            #             shutil.make_archive("predictions", 'zip', "predictions")
         | 
| 242 | 
            +
            #             with open('predictions.zip', 'rb') as f:
         | 
| 243 | 
            +
            #                 res = f.read()
         | 
| 244 |  | 
| 245 | 
            +
            #             download = st.download_button(label="Download", data=res, key="download", 
         | 
| 246 | 
            +
            #                                             file_name=f'{fname}_{int(scale*128)}.zip', 
         | 
| 247 | 
            +
            #                                             # disabled=st.session_state.get("disabled", True), 
         | 
| 248 | 
            +
            #                                             mime="application/octet-stream")
         | 
