File size: 3,662 Bytes
ab687e7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import streamlit as st
import numpy as np
import os
import pathlib
from inference import infer, InferenceModel

# -----------------------------------------------------------------------------
# class SatvisionDemoApp
#
# Streamlit front end for the Satvision-Base image reconstruction demo.
#
# Data layout (paths are relative to the current working directory):
#   data/thumbnails/sv-*.png   previews listed in the sidebar
#   data/images/sv-*.npy       arrays passed to the model; thumbnail
#                              sv-<name>.png is paired with sv-<name>.npy
# -----------------------------------------------------------------------------
class SatvisionDemoApp:
    """Pick a sample image, run it through the model, and show the result."""

    # -------------------------------------------------------------------------
    # __init__
    # -------------------------------------------------------------------------
    def __init__(self):
        """Index the thumbnail and image files on disk and create the model."""
        self.thumbnail_dir = pathlib.Path('data/thumbnails')
        self.image_dir = pathlib.Path('data/images')

        # Sorted so the option order is deterministic across reruns.
        self.thumbnail_files = sorted(self.thumbnail_dir.glob('sv-*.png'))
        # NOTE(review): image_files is not read anywhere in this class;
        # kept in case external code relies on the attribute.
        self.image_files = sorted(self.image_dir.glob('sv-*.npy'))
        # Bare file names double as selectbox options and sidebar captions.
        self.thumbnail_names = [path.name for path in self.thumbnail_files]

        # NOTE(review): Streamlit re-executes the script on every widget
        # interaction and main() builds a fresh app each time, so the model
        # is re-created per rerun. Consider caching it (st.cache_resource).
        self.inferenceModel = InferenceModel()

    # -------------------------------------------------------------------------
    # render_sidebar
    # -------------------------------------------------------------------------
    def render_sidebar(self):
        """Show every thumbnail in the sidebar, captioned with its file name."""
        st.sidebar.header("Select an Image")

        for name in self.thumbnail_names:
            st.sidebar.image(str(self.thumbnail_dir / name),
                             use_column_width=True, caption=name)

    # -------------------------------------------------------------------------
    # render_main_app
    # -------------------------------------------------------------------------
    def render_main_app(self):
        """Draw the page title, the image picker and the three result panels.

        The array paired with the chosen thumbnail is passed to
        ``self.inferenceModel.infer``. Its three return values are shown side
        by side as the input, the masked input and the reconstruction.
        """
        st.title("Satvision-Base Demo")

        st.header("Image Reconstruction Process")
        selected_name = st.sidebar.selectbox(
            "Select an Image",
            self.thumbnail_names)

        # selectbox yields None when it has no options (no thumbnails found);
        # stop here instead of failing inside load_selected_image.
        if selected_name is None:
            st.warning(f"No sv-*.png thumbnails found in {self.thumbnail_dir}")
            return

        selected_image = self.load_selected_image(selected_name)

        image, masked_input, output = self.inferenceModel.infer(selected_image)

        # One column per stage of the reconstruction pipeline.
        col1, col2, col3 = st.columns(3, gap="large")

        with col1:
            st.image(image, use_column_width=True, caption="Input")

        with col2:
            st.image(masked_input, use_column_width=True, caption="Input Masked")

        with col3:
            st.image(output, use_column_width=True, caption="Reconstruction")

    # -------------------------------------------------------------------------
    # load_selected_image
    # -------------------------------------------------------------------------
    def load_selected_image(self, image_name):
        """Load the array that pairs with the thumbnail ``image_name``.

        Args:
            image_name: Thumbnail file name such as ``'sv-foo.png'``. Its
                extension is swapped for ``.npy`` and the file is read from
                ``self.image_dir``.

        Returns:
            The loaded array with axis 0 moved to position 2. This is
            presumably channels-first to channels-last for a 3-D
            (C, H, W) array; confirm against the data.

        Raises:
            FileNotFoundError: If the matching ``.npy`` file does not exist.
        """
        # with_suffix swaps only the trailing extension; str.replace would
        # also rewrite any '.png' that occurs earlier in the name.
        npy_name = pathlib.Path(image_name).with_suffix('.npy')
        image = np.load(self.image_dir / npy_name)
        return np.moveaxis(image, 0, 2)

# -----------------------------------------------------------------------------
# main
# -----------------------------------------------------------------------------
def main():
    """Entry point: build the demo and draw it for the current Streamlit run.

    The main pane is drawn before the sidebar gallery. ``render_main_app``
    adds the image selectbox to the sidebar, so the picker is placed there
    before the thumbnails that ``render_sidebar`` adds.
    """
    demo = SatvisionDemoApp()
    for draw in (demo.render_main_app, demo.render_sidebar):
        draw()


if __name__ == "__main__":
    main()