Spaces:
Runtime error
Runtime error
File size: 5,972 Bytes
9fcd62f 408afad 9fcd62f 709a47d 9fcd62f e414eda 9fcd62f 3006f1e 4debc65 709a47d 9fcd62f 709a47d 86735e0 709a47d 86735e0 9fcd62f 86735e0 709a47d 86735e0 09edb35 86735e0 709a47d 09edb35 86735e0 709a47d 9fcd62f 709a47d 9fcd62f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 |
import streamlit as st
from streamlit_folium import st_folium
import folium
import logging
import sys
import hydra
from plot_functions import *
import hydra
import torch
from model import LitUnsupervisedSegmenter
from helper import inference_on_location_and_month, inference_on_location
# Initial centre and zoom level of the folium map.
DEFAULT_LATITUDE, DEFAULT_LONGITUDE = 48.81, 2.98
DEFAULT_ZOOM = 5

# Year bounds for the TimeLapse selectors; app() passes them to range(), so
# MAX_YEAR itself is an exclusive bound.
MIN_YEAR, MAX_YEAR = 2018, 2024

# Pixel size of the embedded folium map widget.
FOLIUM_WIDTH, FOLIUM_HEIGHT = 925, 300

# Kept ahead of every other st.* call, as Streamlit has historically required.
st.set_page_config(layout="wide")
@st.cache_resource
def init_cfg(cfg_name):
    """Compose the Hydra config ``cfg_name`` from the ``configs`` directory.

    Wrapped in ``st.cache_resource``, so each distinct ``cfg_name`` is composed
    once per process and the cached object is returned on later calls.

    Args:
        cfg_name: Config name passed to ``hydra.compose``.

    Returns:
        The config object returned by ``hydra.compose``.
    """
    # hydra.initialize() raises if GlobalHydra is already initialized, which
    # happens here for a second cfg_name or after the resource cache is cleared.
    # Reset the global state first so re-initialization succeeds.
    from hydra.core.global_hydra import GlobalHydra

    if GlobalHydra.instance().is_initialized():
        GlobalHydra.instance().clear()
    hydra.initialize(config_path="configs", job_name="corine")
    return hydra.compose(config_name=cfg_name)
@st.cache_resource
def init_app(cfg_name) -> LitUnsupervisedSegmenter:
    """Configure logging, build the segmentation model and load its weights.

    Wrapped in ``st.cache_resource`` so the model is built once per process
    and reused by later Streamlit runs.

    Args:
        cfg_name: Name of the Hydra config, composed via ``init_cfg``.

    Returns:
        A ``LitUnsupervisedSegmenter`` moved to CPU, with the state dict from
        ``biomap/checkpoint/model/model.pt`` loaded into it.
    """
    # logging.basicConfig() does nothing when the root logger already has
    # handlers, so only open the log file when it will actually be installed
    # (otherwise the FileHandler would leak an open file).
    if not logging.getLogger().handlers:
        # basicConfig ignores its ``encoding`` argument when ``handlers`` is
        # passed, so the encoding must be set on the FileHandler itself.
        file_handler = logging.FileHandler(filename='biomap.log', encoding='utf-8')
        stdout_handler = logging.StreamHandler(stream=sys.stdout)
        logging.basicConfig(
            handlers=[file_handler, stdout_handler],
            level=logging.INFO,
            format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
        )

    cfg = init_cfg(cfg_name)
    logging.info("config : %s", cfg)

    nbclasses = cfg.dir_dataset_n_classes
    model = LitUnsupervisedSegmenter(nbclasses, cfg).cpu()
    logging.info("Model Initialiazed")

    model_path = "biomap/checkpoint/model/model.pt"
    # NOTE(review): torch.load unpickles the file; only point this at a
    # trusted checkpoint.
    saved_state_dict = torch.load(model_path, map_location=torch.device("cpu"))
    logging.info("Model weights Loaded")
    model.load_state_dict(saved_state_dict)
    return model
def _init_session_state():
    """Create the boolean session-state flags that ``app`` reads, if missing."""
    for flag in ("infered", "submit", "submit2"):
        if flag not in st.session_state:
            st.session_state[flag] = False


def _render_header():
    """Render the page title and the description of the scoring method."""
    st.markdown("<h1 style='text-align: center;'>🐢 Biomap by Ekimetrics 🐢</h1>", unsafe_allow_html=True)
    st.markdown("<h2 style='text-align: center;'>Estimate Biodiversity in the world with the help of land cover.</h2>", unsafe_allow_html=True)
    st.markdown("<p style='text-align: center;'>The segmentation model is an association of UNet and DinoV1 trained on the dataset CORINE. Land use is divided into 6 differents classes : Each class is assigned a GBS score from 0 to 1</p>", unsafe_allow_html=True)
    st.markdown("<p style='text-align: center;'>Buildings : 0.1 | Infrastructure : 0.1 | Cultivation : 0.4 | Wetland : 0.9 | Water : 0.9 | Natural green : 1 </p>", unsafe_allow_html=True)
    st.markdown("<p style='text-align: center;'>The score is then averaged on the full image.</p>", unsafe_allow_html=True)


def _select_location():
    """Render the folium map; return the last clicked (lat, lng), else the defaults."""
    m = folium.Map(location=[DEFAULT_LATITUDE, DEFAULT_LONGITUDE], zoom_start=DEFAULT_ZOOM)
    m.add_child(folium.LatLngPopup())
    f_map = st_folium(m, width=FOLIUM_WIDTH, height=FOLIUM_HEIGHT)
    last_clicked = f_map.get("last_clicked")
    if last_clicked:
        return last_clicked["lat"], last_clicked["lng"]
    return DEFAULT_LATITUDE, DEFAULT_LONGITUDE


def _render_timelapse_tab(selected_latitude, selected_longitude):
    """Render the TimeLapse inputs and store their values in ``st.session_state``."""
    st.session_state["submit"] = st.button("Predict TimeLapse", use_container_width=True, type="primary")
    col_lat, col_long = st.columns(2)
    with col_lat:
        st.session_state["lat"] = st.text_input("latitude", value=selected_latitude)
    with col_long:
        st.session_state["long"] = st.text_input("longitude", value=selected_longitude)

    col_start, col_end = st.columns(2)
    years = list(range(MIN_YEAR, MAX_YEAR, 1))
    with col_start:
        # The last year is not offered as a start: no later year would remain
        # for "End date", and an empty selectbox returns None.
        start_date = st.selectbox("Start date", years[:-1])
        st.session_state["start_date"] = start_date
    end_years = [year for year in years if year > start_date]
    with col_end:
        st.session_state["end_date"] = st.selectbox("End date", end_years)

    st.session_state["segment_interval"] = st.radio("Interval of time between two segmentation", options=['month','2months', 'year'],horizontal=True)


def _render_single_image_tab(selected_latitude, selected_longitude):
    """Render the Single Image inputs and store their values in ``st.session_state``."""
    st.session_state["submit2"] = st.button("Predict Single Image", use_container_width=True, type="primary")
    col_lat, col_long = st.columns(2)
    with col_lat:
        st.session_state["lat_2"] = st.text_input("lat.", value=selected_latitude)
    with col_long:
        st.session_state["long_2"] = st.text_input("long.", value=selected_longitude)
    st.session_state["date_2"] = st.text_input("date", "2021-01-01", placeholder="2021-01-01")


def _run_requested_inference(model):
    """Run inference for whichever predict button was pressed in this run."""
    if st.session_state["submit"]:
        fig = inference_on_location(model, st.session_state["lat"], st.session_state["long"], st.session_state["start_date"], st.session_state["end_date"], st.session_state["segment_interval"])
        st.session_state["infered"] = True
        st.session_state["previous_fig"] = fig
    if st.session_state["submit2"]:
        fig = inference_on_location_and_month(model, st.session_state["lat_2"], st.session_state["long_2"], st.session_state["date_2"])
        st.session_state["infered"] = True
        st.session_state["previous_fig"] = fig


def app(model):
    """Render the Biomap page and run the segmentation model on request.

    The map click (if any) seeds the latitude/longitude inputs of both tabs.
    When a predict button is pressed, the matching helper from ``helper`` is
    called and its figure is kept in ``st.session_state["previous_fig"]``, so
    the last result stays on screen across later reruns.

    Args:
        model: The segmentation model passed through to the inference helpers.
    """
    _init_session_state()
    _render_header()

    # Reserve the spot at the top of the page for the result chart. It is filled
    # at the end of this run, after the buttons below have been read, so a click
    # triggers inference immediately instead of on the following rerun.
    result_slot = st.container()

    col_1, col_2 = st.columns([0.5, 0.5])
    with col_1:
        selected_latitude, selected_longitude = _select_location()
    with col_2:
        tab_timelapse, tab_single = st.tabs(["TimeLapse", "Single Image"])
        with tab_timelapse:
            _render_timelapse_tab(selected_latitude, selected_longitude)
        with tab_single:
            _render_single_image_tab(selected_latitude, selected_longitude)

    _run_requested_inference(model)
    if st.session_state["infered"]:
        result_slot.plotly_chart(st.session_state["previous_fig"], use_container_width=True)
if __name__ == "__main__":
    # init_app is wrapped in st.cache_resource, so repeated runs of this script
    # reuse the same model instead of rebuilding it.
    model = init_app(cfg_name="my_train_config.yml")
    app(model=model)