import streamlit as st
import numpy as np
import os
import pathlib
from inference import infer, InferenceModel
from text import intro

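# -----------------------------------------------------------------------------
# class SatvisionDemoApp
#
# Streamlit front end for the SatVision-Base image-reconstruction demo.
#
# Expected data layout (paths are relative to the working directory):
#     data/thumbnails/sv-*.png         thumbnails listed in the sidebar
#     data/images/sv-*.npy             arrays passed to the model, paired with
#                                      the thumbnails by file name
#     data/figures/reconstruction.png  figure shown below the intro text
# -----------------------------------------------------------------------------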
class SatvisionDemoApp:
    """Streamlit demo that runs the inference model on bundled sample images.

    Sample files are discovered once, at construction time: sidebar thumbnails
    are ``data/thumbnails/sv-*.png`` and the arrays fed to the model are
    ``data/images/sv-*.npy``, paired with the thumbnails by file name.
    """

    # -------------------------------------------------------------------------
    # __init__
    # -------------------------------------------------------------------------
    def __init__(self):
        """Discover the sample files and construct the inference model."""

        self.thumbnail_dir = pathlib.Path('data/thumbnails')
        self.image_dir = pathlib.Path('data/images')

        # Sorted so the sidebar and the select box list images in a stable order.
        self.thumbnail_files = sorted(self.thumbnail_dir.glob('sv-*.png'))
        self.image_files = sorted(self.image_dir.glob('sv-*.npy'))
        self.thumbnail_names = [tn_path.name for tn_path in self.thumbnail_files]

        self.inferenceModel = InferenceModel()

    # -------------------------------------------------------------------------
    # render_sidebar
    # -------------------------------------------------------------------------
    def render_sidebar(self):
        """Show every thumbnail in the sidebar, captioned with its file name."""

        st.sidebar.header("Select an Image")

        for thumbnail in self.thumbnail_names:
            thumbnail_path = self.thumbnail_dir / thumbnail
            st.sidebar.image(str(thumbnail_path), use_column_width=True, caption=thumbnail)

    # -------------------------------------------------------------------------
    # render_main_app
    # -------------------------------------------------------------------------
    def render_main_app(self):
        """Render the main page: title, reconstruction panel, intro and figure.

        The image is chosen with a sidebar select box over the thumbnail file
        names. When no thumbnails were found, a warning replaces the
        reconstruction panel, and the intro text and figure are still shown.
        """

        st.title("Satvision-Base Demo")

        st.header("Image Reconstruction Process")
        selected_image_name = st.sidebar.selectbox(
                    "Select an Image",
                    self.thumbnail_names)

        # selectbox yields None when it has no options; passing that on would
        # crash load_selected_image with an AttributeError.
        if selected_image_name is None:
            st.warning(f"No sample images (sv-*.png) found in '{self.thumbnail_dir}'.")
        else:
            self._render_reconstruction(selected_image_name)

        st.markdown(intro)

        st.image('data/figures/reconstruction.png')

    # -------------------------------------------------------------------------
    # _render_reconstruction
    # -------------------------------------------------------------------------
    def _render_reconstruction(self, image_name):
        """Run the model on *image_name* and show its three outputs side by side."""

        selected_image = self.load_selected_image(image_name)

        # NOTE(review): infer is assumed to return (input, masked input,
        # reconstruction), matching the captions below; confirm in inference.py.
        image, masked_input, output = self.inferenceModel.infer(selected_image)

        col1, col2, col3 = st.columns(3, gap="large")

        with col1:
            st.image(image, use_column_width=True, caption="Input")

        with col2:
            st.image(masked_input, use_column_width=True, caption="Input Masked")

        with col3:
            st.image(output, use_column_width=True, caption="Reconstruction")

    # -------------------------------------------------------------------------
    # load_selected_image
    # -------------------------------------------------------------------------
    def load_selected_image(self, image_name):
        """Load the ``.npy`` array that pairs with a thumbnail file name.

        Args:
            image_name: thumbnail file name, e.g. ``'sv-<id>.png'``. Only its
                final extension is swapped for ``.npy``, and the result is
                looked up in ``self.image_dir``.

        Returns:
            The loaded array with its first axis moved to position 2. The
            array must therefore have at least three dimensions.

        Raises:
            FileNotFoundError: if the matching ``.npy`` file does not exist.
        """

        # with_suffix swaps only the trailing extension; str.replace would also
        # rewrite any '.png' appearing earlier in the name.
        array_name = pathlib.Path(image_name).with_suffix('.npy').name

        image = np.load(self.image_dir / array_name)

        # NOTE(review): presumably stored channels-first (C, H, W), so this
        # yields (H, W, C); confirm against the data files.
        return np.moveaxis(image, 0, 2)

# -----------------------------------------------------------------------------
# main
# -----------------------------------------------------------------------------
def main():
    """Create the demo and draw it: the main page first, then the sidebar."""

    # NOTE(review): Streamlit re-executes the script on each widget interaction,
    # so the app (and its InferenceModel) is likely rebuilt every time. Consider
    # caching it if construction is expensive; confirm.
    demo = SatvisionDemoApp()

    for render in (demo.render_main_app, demo.render_sidebar):
        render()


if __name__ == "__main__":
    main()