File size: 5,972 Bytes
9fcd62f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
408afad
9fcd62f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
709a47d
 
 
 
9fcd62f
 
e414eda
 
9fcd62f
3006f1e
4debc65
709a47d
 
 
 
9fcd62f
709a47d
 
 
 
 
 
 
 
86735e0
 
 
 
 
 
709a47d
 
 
 
 
 
 
86735e0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9fcd62f
86735e0
 
709a47d
86735e0
 
 
 
09edb35
 
86735e0
709a47d
09edb35
86735e0
709a47d
9fcd62f
709a47d
 
9fcd62f
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import streamlit as st
from streamlit_folium import st_folium
import folium
import logging
import sys
import hydra
from plot_functions import *
import hydra

import torch
from model import LitUnsupervisedSegmenter
from helper import inference_on_location_and_month, inference_on_location

# Initial center of the folium map; also the default coordinates shown in the
# latitude/longitude inputs until the user clicks on the map.
DEFAULT_LATITUDE = 48.81
DEFAULT_LONGITUDE = 2.98
# Initial zoom level of the folium map.
DEFAULT_ZOOM = 5

# Bounds of the year selectors; MAX_YEAR is used as an exclusive range stop.
MIN_YEAR = 2018
MAX_YEAR = 2024

# Width and height passed to st_folium when rendering the map.
FOLIUM_WIDTH = 925
FOLIUM_HEIGHT = 300


# Use Streamlit's "wide" page layout.
st.set_page_config(layout="wide")
@st.cache_resource
def init_cfg(cfg_name):
    """Initialize Hydra and compose the configuration named *cfg_name*.

    The result is cached by ``st.cache_resource``, so the body normally runs
    once per distinct ``cfg_name``.

    Args:
        cfg_name: Name of the config file to compose, looked up in the
            ``configs`` directory.

    Returns:
        The composed Hydra configuration object.
    """
    # Imported locally so the file's top-level import block stays untouched.
    from hydra.core.global_hydra import GlobalHydra

    # hydra.initialize() raises when Hydra is already initialized in this
    # process, which happens when the Streamlit cache is cleared or another
    # config name is requested. Reset the global state first.
    if GlobalHydra.instance().is_initialized():
        GlobalHydra.instance().clear()

    hydra.initialize(config_path="configs", job_name="corine")
    return hydra.compose(config_name=cfg_name)

@st.cache_resource
def init_app(cfg_name) -> LitUnsupervisedSegmenter:
    """Set up logging, build the segmentation model and load its weights.

    The returned model is cached by ``st.cache_resource``.

    Args:
        cfg_name: Name of the Hydra config forwarded to ``init_cfg``.

    Returns:
        A ``LitUnsupervisedSegmenter`` on CPU with the checkpoint's state
        dict loaded.
    """
    # Log both to a file and to stdout.
    logging.basicConfig(
        handlers=[
            logging.FileHandler(filename='biomap.log'),
            logging.StreamHandler(stream=sys.stdout),
        ],
        encoding='utf-8',
        level=logging.INFO,
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    )

    config = init_cfg(cfg_name)
    logging.info(f"config : {config}")

    # The number of classes comes from the config; the model is kept on CPU.
    n_classes = config.dir_dataset_n_classes
    segmenter = LitUnsupervisedSegmenter(n_classes, config).cpu()
    logging.info("Model Initialiazed")

    # map_location forces the checkpoint tensors onto the CPU.
    checkpoint_path = "biomap/checkpoint/model/model.pt"
    weights = torch.load(checkpoint_path, map_location=torch.device("cpu"))
    logging.info("Model weights Loaded")
    segmenter.load_state_dict(weights)
    return segmenter

def app(model):
    """Render the Biomap page and run inference when a prediction is requested.

    The page shows a header, the last computed figure (if any), a clickable
    folium map on the left and two tabs of inputs on the right ("TimeLapse"
    and "Single Image"). Every input value is mirrored into
    ``st.session_state`` under the same keys the inference calls read.

    Args:
        model: The segmentation model forwarded to ``inference_on_location``
            and ``inference_on_location_and_month``.
    """
    # All flags start as False on the first run of a session.
    for flag in ("infered", "submit", "submit2"):
        if flag not in st.session_state:
            st.session_state[flag] = False

    st.markdown("<h1 style='text-align: center;'>🐢 Biomap by Ekimetrics 🐢</h1>", unsafe_allow_html=True)
    st.markdown("<h2 style='text-align: center;'>Estimate Biodiversity in the world with the help of land cover.</h2>", unsafe_allow_html=True)
    st.markdown("<p style='text-align: center;'>The segmentation model is an association of UNet and DinoV1 trained on the dataset CORINE. Land use is divided into 6 differents classes : Each class is assigned a GBS score from 0 to 1</p>", unsafe_allow_html=True)
    st.markdown("<p style='text-align: center;'>Buildings : 0.1 | Infrastructure : 0.1 | Cultivation : 0.4 | Wetland : 0.9 | Water : 0.9 | Natural green : 1 </p>", unsafe_allow_html=True)
    st.markdown("<p style='text-align: center;'>The score is then averaged on the full image.</p>", unsafe_allow_html=True)

    # The submit flags are read here, before the buttons below are rendered,
    # so they hold the values stored by the previous script run.
    # NOTE(review): this means inference starts on the rerun following the one
    # in which the button returned True — confirm this is intended.
    if st.session_state["submit"]:
        fig = inference_on_location(
            model,
            st.session_state["lat"],
            st.session_state["long"],
            st.session_state["start_date"],
            st.session_state["end_date"],
            st.session_state["segment_interval"],
        )
        st.session_state["infered"] = True
        st.session_state["previous_fig"] = fig

    if st.session_state["submit2"]:
        fig = inference_on_location_and_month(
            model,
            st.session_state["lat_2"],
            st.session_state["long_2"],
            st.session_state["date_2"],
        )
        st.session_state["infered"] = True
        st.session_state["previous_fig"] = fig

    # Keep showing the most recent figure across reruns.
    if st.session_state["infered"]:
        st.plotly_chart(st.session_state["previous_fig"], use_container_width=True)

    col_1, col_2 = st.columns([0.5, 0.5])
    with col_1:
        m = folium.Map(location=[DEFAULT_LATITUDE, DEFAULT_LONGITUDE], zoom_start=DEFAULT_ZOOM)
        m.add_child(folium.LatLngPopup())
        f_map = st_folium(m, width=FOLIUM_WIDTH, height=FOLIUM_HEIGHT)

        # Fall back to the defaults until the user clicks on the map.
        selected_latitude = DEFAULT_LATITUDE
        selected_longitude = DEFAULT_LONGITUDE

        if f_map.get("last_clicked"):
            selected_latitude = f_map["last_clicked"]["lat"]
            selected_longitude = f_map["last_clicked"]["lng"]

    with col_2:
        tabs1, tabs2 = st.tabs(["TimeLapse", "Single Image"])
        with tabs1:
            submit = st.button("Predict TimeLapse", use_container_width=True, type="primary")
            st.session_state["submit"] = submit

            col_tab1_1, col_tab1_2 = st.columns(2)
            with col_tab1_1:
                lat = st.text_input("latitude", value=selected_latitude)
                st.session_state["lat"] = lat
            with col_tab1_2:
                long = st.text_input("longitude", value=selected_longitude)
                st.session_state["long"] = long

            col_tab1_11, col_tab1_22 = st.columns(2)
            years = list(range(MIN_YEAR, MAX_YEAR, 1))
            with col_tab1_11:
                # The last year is not offered as a start date: picking it
                # would leave the "End date" list empty and end_date would
                # be None.
                start_date = st.selectbox("Start date", years[:-1])
                st.session_state["start_date"] = start_date

            # Only years strictly after the chosen start are valid end dates.
            end_years = [year for year in years if year > start_date]
            with col_tab1_22:
                end_date = st.selectbox("End date", end_years)
                st.session_state["end_date"] = end_date

            segment_interval = st.radio("Interval of time between two segmentation", options=['month','2months', 'year'],horizontal=True)
            st.session_state["segment_interval"] = segment_interval

        with tabs2:
            submit2 = st.button("Predict Single Image", use_container_width=True, type="primary")
            st.session_state["submit2"] = submit2

            col_tab2_1, col_tab2_2 = st.columns(2)
            with col_tab2_1:
                lat_2 = st.text_input("lat.", value=selected_latitude)
                st.session_state["lat_2"] = lat_2
            with col_tab2_2:
                long_2 = st.text_input("long.", value=selected_longitude)
                st.session_state["long_2"] = long_2

            date_2 = st.text_input("date", "2021-01-01", placeholder="2021-01-01")
            st.session_state["date_2"] = date_2


if __name__ == "__main__":
    # Build the model (cached by init_app) and render the page with it.
    app(init_app("my_train_config.yml"))