File size: 4,835 Bytes
9fcd62f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
851dbaf
9fcd62f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import streamlit as st
from streamlit_folium import st_folium
import folium
import logging
import sys
import hydra
from plot_functions import *
import hydra

import torch
from model import LitUnsupervisedSegmenter
from helper import inference_on_location_and_month, inference_on_location

# Initial folium map view: centre point (latitude, longitude) and zoom level.
DEFAULT_LATITUDE, DEFAULT_LONGITUDE = 48.81, 2.98
DEFAULT_ZOOM = 5

# Year bounds used to build the TimeLapse start/end year selectors.
MIN_YEAR, MAX_YEAR = 2018, 2024

# Pixel dimensions of the embedded folium map.
FOLIUM_WIDTH, FOLIUM_HEIGHT = 925, 500


# Use Streamlit's wide page layout so the two side-by-side columns created in
# app() (map | controls) have room to render.
st.set_page_config(layout="wide")
@st.cache_resource
def init_cfg(cfg_name):
    """Compose and return the Hydra configuration named *cfg_name*.

    Configs are looked up in the ``configs`` directory (Hydra resolves this
    ``config_path`` relative to the calling file) with job name ``corine``.
    The result is cached by ``st.cache_resource``, so the config is composed
    once per ``cfg_name``.

    ``hydra.initialize`` is used as a context manager so Hydra's global state
    is released as soon as the config has been composed. Calling it bare
    leaves ``GlobalHydra`` initialized, and any later call (another
    ``cfg_name``, or after the Streamlit cache is cleared) fails with
    "GlobalHydra is already initialized".

    Args:
        cfg_name: Name of the config to compose, e.g. ``"my_train_config.yml"``.

    Returns:
        The composed configuration object returned by ``hydra.compose``.
    """
    with hydra.initialize(config_path="configs", job_name="corine"):
        return hydra.compose(config_name=cfg_name)

@st.cache_resource
def init_app(
    cfg_name: str,
    model_path: str = "biomap/checkpoint/model/model.pt",
) -> LitUnsupervisedSegmenter:
    """Set up logging and return the segmentation model with its trained weights.

    Wrapped in ``st.cache_resource``, so for a given ``(cfg_name, model_path)``
    the body runs once and later Streamlit reruns reuse the same model.

    Args:
        cfg_name: Name of the Hydra config, passed to ``init_cfg``.
        model_path: Path of the saved ``state_dict`` checkpoint to load.
            Defaults to the location this app has always used.

    Returns:
        A ``LitUnsupervisedSegmenter`` moved to CPU with the checkpoint loaded.
    """
    # logging.basicConfig only honours ``encoding`` when it builds the
    # FileHandler itself (``filename=``). With explicit ``handlers`` the
    # argument is ignored, so UTF-8 has to be set on the handler directly.
    file_handler = logging.FileHandler(filename="biomap.log", encoding="utf-8")
    stdout_handler = logging.StreamHandler(stream=sys.stdout)
    logging.basicConfig(
        handlers=[file_handler, stdout_handler],
        level=logging.INFO,
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    )

    cfg = init_cfg(cfg_name)
    logging.info("config : %s", cfg)
    nbclasses = cfg.dir_dataset_n_classes
    model = LitUnsupervisedSegmenter(nbclasses, cfg).cpu()
    logging.info("Model Initialiazed")

    # torch.load unpickles the file, so model_path must point at a trusted
    # checkpoint. map_location forces every tensor onto the CPU.
    saved_state_dict = torch.load(model_path, map_location=torch.device("cpu"))
    logging.info("Model weights Loaded")
    model.load_state_dict(saved_state_dict)
    # NOTE(review): the model is returned without model.eval(); confirm the
    # inference helpers put it in eval mode before predicting.
    return model

def app(model):
    """Render the Biomap Streamlit page and run inference when requested.

    The page shows an introduction, a folium map (the last clicked point
    pre-fills the coordinate inputs) and two tabs:

    * "TimeLapse": segments a location over a range of years at a chosen
      interval, via ``inference_on_location``.
    * "Single Image": segments a location at one date, via
      ``inference_on_location_and_month``.

    The most recent figure is stored in ``st.session_state`` and redrawn with
    ``st.plotly_chart`` on every rerun, so it stays visible after later
    widget interactions.

    Args:
        model: The segmentation model returned by ``init_app``.
    """
    if "infered" not in st.session_state:
        st.session_state["infered"] = False

    st.markdown("<h1 style='text-align: center;'>🐢 Biomap by Ekimetrics 🐢</h1>", unsafe_allow_html=True)
    st.markdown("<h2 style='text-align: center;'>Estimate Biodiversity score in the world with the help of land use.</h2>", unsafe_allow_html=True)
    st.markdown("<p style='text-align: center;'>The segmentation is an association of UNet and DinoV1 trained on the dataset CORINE.</p>", unsafe_allow_html=True)
    st.markdown("<p style='text-align: center;'>Land use is divided into 6 differents classes :Each class is assigned a GBS score from 0 to 1</p>", unsafe_allow_html=True)
    st.markdown("<p style='text-align: center;'>Buildings : 0.1 | Infrastructure : 0.1 | Cultivation : 0.4 | Wetland : 0.9 | Water : 0.9 | Natural green : 1 </p>", unsafe_allow_html=True)
    st.markdown("<p style='text-align: center;'>The score is then average on the full image.</p>", unsafe_allow_html=True)

    col_1, col_2 = st.columns([0.5, 0.5])
    with col_1:
        m = folium.Map(location=[DEFAULT_LATITUDE, DEFAULT_LONGITUDE], zoom_start=DEFAULT_ZOOM)
        # Show a popup with the latitude/longitude of the clicked point.
        m.add_child(folium.LatLngPopup())
        f_map = st_folium(m, width=FOLIUM_WIDTH, height=FOLIUM_HEIGHT)

    # Default the coordinate inputs to the last point clicked on the map, if any.
    selected_latitude = DEFAULT_LATITUDE
    selected_longitude = DEFAULT_LONGITUDE
    last_clicked = f_map.get("last_clicked")
    if last_clicked:
        selected_latitude = last_clicked["lat"]
        selected_longitude = last_clicked["lng"]

    with col_2:
        tab_timelapse, tab_single = st.tabs(["TimeLapse", "Single Image"])
        with tab_timelapse:
            # Each tab keeps its own coordinate variables, so each button uses
            # the inputs from its own tab. Sharing one `lat`/`long` pair would
            # let the Single Image inputs overwrite these.
            timelapse_lat = st.text_input("lattitude", value=selected_latitude)
            timelapse_long = st.text_input("longitude", value=selected_longitude)

            # range() excludes MAX_YEAR, so the selectable years are
            # MIN_YEAR .. MAX_YEAR - 1.
            years = list(range(MIN_YEAR, MAX_YEAR, 1))
            # The final year is not offered as a start year. No later year
            # would remain for the end date, leaving "End date" with no
            # options (st.selectbox would then return None).
            start_date = st.selectbox("Start date", years[:-1])

            end_years = [year for year in years if year > start_date]
            end_date = st.selectbox("End date", end_years)

            segment_interval = st.radio("Interval of time between two segmentation", options=['month','2months', 'year'],horizontal=True)
            submit = st.button("Predict TimeLapse", use_container_width=True)
        with tab_single:
            single_lat = st.text_input("lat.", value=selected_latitude)
            single_long = st.text_input("long.", value=selected_longitude)

            date = st.text_input("date", "2021-01-01", placeholder="2021-01-01")

            submit2 = st.button("Predict Single Image", use_container_width=True)

    if submit:
        fig = inference_on_location(model, timelapse_lat, timelapse_long, start_date, end_date, segment_interval)
        st.session_state["infered"] = True
        st.session_state["previous_fig"] = fig

    if submit2:
        fig = inference_on_location_and_month(model, single_lat, single_long, date)
        st.session_state["infered"] = True
        st.session_state["previous_fig"] = fig

    # Redraw the most recent prediction (if any) on every rerun.
    if st.session_state["infered"]:
        st.plotly_chart(st.session_state["previous_fig"], use_container_width=True)

        
    
if __name__ == "__main__":
    # init_app is wrapped in st.cache_resource, so Streamlit re-executing this
    # script on each interaction reuses the already-loaded model.
    app(init_app("my_train_config.yml"))