File size: 4,986 Bytes
9fcd62f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
408afad
9fcd62f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e414eda
 
9fcd62f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
09edb35
 
 
 
 
9fcd62f
09edb35
9fcd62f
09edb35
 
9fcd62f
 
09edb35
 
9fcd62f
 
 
 
09edb35
 
 
 
 
9fcd62f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69b90e4
 
 
9fcd62f
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import streamlit as st
from streamlit_folium import st_folium
import folium
import logging
import sys
import hydra
from plot_functions import *
import hydra

import torch
from model import LitUnsupervisedSegmenter
from helper import inference_on_location_and_month, inference_on_location

# Initial centre (decimal degrees) and zoom level of the folium map.
DEFAULT_LATITUDE, DEFAULT_LONGITUDE = 48.81, 2.98
DEFAULT_ZOOM = 5

# Year bounds for the TimeLapse selectors; MAX_YEAR is used as a `range`
# stop, so it is exclusive.
MIN_YEAR, MAX_YEAR = 2018, 2024

# Pixel size of the embedded folium map.
FOLIUM_WIDTH, FOLIUM_HEIGHT = 925, 300


# Use the full browser width for the page layout.
# NOTE: kept before any other Streamlit call — older Streamlit versions require
# set_page_config to be the first st command executed by the script.
st.set_page_config(layout="wide")
@st.cache_resource
def init_cfg(cfg_name):
    """Compose the Hydra config ``cfg_name`` from the local ``configs`` directory.

    Cached with ``st.cache_resource`` so composition happens once per config
    name and the result is reused across Streamlit reruns.

    Args:
        cfg_name: name of the config to compose, looked up in ``configs``.

    Returns:
        The composed configuration returned by ``hydra.compose``.
    """
    # Using initialize() as a context manager restores Hydra's global state on
    # exit. The previous bare call left GlobalHydra initialised, so any second
    # call (after the Streamlit cache is cleared, or for another cfg_name)
    # raised "GlobalHydra is already initialized".
    with hydra.initialize(config_path="configs", job_name="corine"):
        return hydra.compose(config_name=cfg_name)

@st.cache_resource
def init_app(cfg_name, model_path="biomap/checkpoint/model/model.pt") -> LitUnsupervisedSegmenter:
    """Configure logging, build the segmentation model and load its weights.

    Decorated with ``st.cache_resource`` so this runs once per distinct set of
    arguments and the same model is reused across Streamlit reruns.

    Args:
        cfg_name: Hydra config name, composed via ``init_cfg``.
        model_path: path of the ``state_dict`` checkpoint, loaded onto the CPU.
            Defaults to the location that used to be hard-coded here.

    Returns:
        A CPU ``LitUnsupervisedSegmenter`` with the checkpoint weights loaded.
    """
    # `encoding` passed to logging.basicConfig is ignored whenever explicit
    # `handlers` are given, so UTF-8 has to be set on the FileHandler itself.
    file_handler = logging.FileHandler(filename='biomap.log', encoding='utf-8')
    stdout_handler = logging.StreamHandler(stream=sys.stdout)
    logging.basicConfig(
        handlers=[file_handler, stdout_handler],
        level=logging.INFO,
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    )

    cfg = init_cfg(cfg_name)
    logging.info(f"config : {cfg}")

    nbclasses = cfg.dir_dataset_n_classes
    model = LitUnsupervisedSegmenter(nbclasses, cfg)
    model = model.cpu()
    logging.info("Model Initialiazed")

    saved_state_dict = torch.load(model_path, map_location=torch.device("cpu"))
    logging.info("Model weights Loaded")
    model.load_state_dict(saved_state_dict)
    # NOTE(review): the model is returned without model.eval(); torch modules
    # start in training mode. Confirm the inference helpers switch to eval mode.
    return model

def app(model):
    """Render the Biomap Streamlit page and run inference on demand.

    Layout: a clickable folium map on the left, and two tabs on the right.
    "TimeLapse" calls ``inference_on_location`` over a year range, and
    "Single Image" calls ``inference_on_location_and_month`` for one date.
    The last produced figure is stored in ``st.session_state`` so it keeps
    being displayed on later reruns.

    Args:
        model: segmentation model forwarded unchanged to the inference helpers
            (the value returned by ``init_app``).
    """
    if "infered" not in st.session_state:
        st.session_state["infered"] = False

    st.markdown("<h1 style='text-align: center;'>🐢 Biomap by Ekimetrics 🐢</h1>", unsafe_allow_html=True)
    st.markdown("<h2 style='text-align: center;'>Estimate Biodiversity in the world with the help of land cover.</h2>", unsafe_allow_html=True)
    st.markdown("<p style='text-align: center;'>The segmentation model is an association of UNet and DinoV1 trained on the dataset CORINE. Land use is divided into 6 differents classes : Each class is assigned a GBS score from 0 to 1</p>", unsafe_allow_html=True)
    st.markdown("<p style='text-align: center;'>Buildings : 0.1 | Infrastructure : 0.1 | Cultivation : 0.4 | Wetland : 0.9 | Water : 0.9 | Natural green : 1 </p>", unsafe_allow_html=True)
    st.markdown("<p style='text-align: center;'>The score is then average on the full image.</p>", unsafe_allow_html=True)

    col_1, col_2 = st.columns([0.5, 0.5])
    with col_1:
        m = folium.Map(location=[DEFAULT_LATITUDE, DEFAULT_LONGITUDE], zoom_start=DEFAULT_ZOOM)
        m.add_child(folium.LatLngPopup())
        f_map = st_folium(m, width=FOLIUM_WIDTH, height=FOLIUM_HEIGHT)

    # Pre-fill the coordinate inputs with the last map click, if any.
    selected_latitude = DEFAULT_LATITUDE
    selected_longitude = DEFAULT_LONGITUDE
    last_clicked = (f_map or {}).get("last_clicked")
    if last_clicked:
        selected_latitude = last_clicked["lat"]
        selected_longitude = last_clicked["lng"]

    with col_2:
        tabs1, tabs2 = st.tabs(["TimeLapse", "Single Image"])
        # Both tab bodies execute on every rerun, so each tab keeps its own
        # coordinate variables. Sharing `lat`/`long` made the Single Image
        # inputs silently overwrite the TimeLapse ones.
        with tabs1:
            col_tab1_1, col_tab1_2 = st.columns(2)
            with col_tab1_1:
                timelapse_lat = st.text_input("lattitude", value=selected_latitude)
            with col_tab1_2:
                timelapse_long = st.text_input("longitude", value=selected_longitude)

            col_tab1_11, col_tab1_22 = st.columns(2)
            years = list(range(MIN_YEAR, MAX_YEAR))
            with col_tab1_11:
                # The final year is not offered as a start year. No later year
                # would remain for "End date", whose selectbox would then have
                # no options and return None.
                start_date = st.selectbox("Start date", years[:-1])

            end_years = [year for year in years if year > start_date]
            with col_tab1_22:
                end_date = st.selectbox("End date", end_years)

            segment_interval = st.radio("Interval of time between two segmentation", options=['month', '2months', 'year'], horizontal=True)
            submit = st.button("Predict TimeLapse", use_container_width=True)
        with tabs2:
            col_tab2_1, col_tab2_2 = st.columns(2)
            with col_tab2_1:
                single_lat = st.text_input("lat.", value=selected_latitude)
            with col_tab2_2:
                single_long = st.text_input("long.", value=selected_longitude)

            date = st.text_input("date", "2021-01-01", placeholder="2021-01-01")

            submit2 = st.button("Predict Single Image", use_container_width=True)

    if submit:
        fig = inference_on_location(model, timelapse_lat, timelapse_long, start_date, end_date, segment_interval)
        st.session_state["infered"] = True
        st.session_state["previous_fig"] = fig

    if submit2:
        fig = inference_on_location_and_month(model, single_lat, single_long, date)
        st.session_state["infered"] = True
        st.session_state["previous_fig"] = fig

    # Re-display the most recent figure so it survives widget-triggered reruns.
    if st.session_state["infered"]:
        st.plotly_chart(st.session_state["previous_fig"], use_container_width=True)
        
    
if __name__ == "__main__":
    # init_app is wrapped in st.cache_resource, so repeated script runs reuse
    # the already-built model instead of reloading the checkpoint.
    model = init_app("my_train_config.yml")
    app(model=model)