import torch
import torch.multiprocessing
import torchvision.transforms as T
import numpy as np
from utils import transform_to_pil, compute_biodiv_score, plot_imgs_labels, plot_image
from utils_gee import get_image
from dateutil.relativedelta import relativedelta
from model import LitUnsupervisedSegmenter
import datetime
import matplotlib as mpl
from joblib import Parallel, cpu_count, delayed
import logging
from inference import inference
import streamlit as st
import cv2
@st.cache_data(hash_funcs={LitUnsupervisedSegmenter: lambda dt: dt.name})
def inference_on_location(model, longitude=2.98, latitude=48.81, start_date=2020, end_date=2022, how="year"):
"""Performe an inference on the latitude and longitude between the start date and the end date
Args:
latitude (float): the latitude of the landscape
longitude (float): the longitude of the landscape
start_date (str): the start date for our inference
end_date (str): the end date for our inference
model (_type_, optional): _description_. Defaults to model.
Returns:
img, labeled_img,biodiv_score: the original landscape, the labeled landscape and the biodiversity score and the landscape
"""
logging.info("Running Inference on location")
logging.info(f"latitude : {latitude} & longitude : {longitude}")
logging.info(f"start date : {start_date} & end_date : {end_date}")
logging.info(f"Prediction on intervale : {how}")
if how == "month":
delta_month = 1
elif how == "2months":
delta_month = 2
elif how == "year":
delta_month = 11
else:
raise ValueError("Wrong interval")
    assert int(end_date) > int(start_date), "end date must be strictly greater than start date"
    location = [float(latitude), float(longitude)]

    # Extract the images as numpy arrays from Earth Engine and transform them to PIL images
    dates = [datetime.datetime(int(start_date), 1, 1, 0, 0, 0)]
    while dates[-1] < datetime.datetime(int(end_date), 1, 1, 0, 0, 0):
        dates.append(dates[-1] + relativedelta(months=delta_month))
    dates = [d.strftime("%Y-%m-%d") for d in dates]
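    # Worked example with the defaults (start_date=2020, end_date=2022, how="year", hence
    # delta_month=11): dates becomes ["2020-01-01", "2020-12-01", "2021-11-01", "2022-10-01"],
    # which the pairwise zip below turns into the three acquisition windows
    # (2020-01-01, 2020-12-01), (2020-12-01, 2021-11-01) and (2021-11-01, 2022-10-01).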
    all_image = Parallel(n_jobs=cpu_count(), prefer="threads")(
        delayed(get_image)(location, d1, d2) for d1, d2 in zip(dates[:-1], dates[1:])
    )
    # all_image = [cv2.imread("output/img.png") for i in range(len(dates))]
    outputs = inference(np.array(all_image), model)

    logging.info("Calculating Biodiversity Scores...")
    scores, scores_details = map(
        list,
        zip(*[compute_biodiv_score(output["linear_preds"].detach().numpy()) for output in outputs]),
    )
    logging.info(f"Calculated Biodiversity Score : {scores}")

    imgs, labels, labeled_imgs = map(list, zip(*[transform_to_pil(output) for output in outputs]))
    images = [np.asarray(img) for img in imgs]
    labeled_imgs = [np.asarray(img) for img in labeled_imgs]

    title = f"TimeLapse at location {tuple(location)} between {start_date} and {end_date}"
    fig = plot_imgs_labels(dates, images, labeled_imgs, scores_details, scores, title=title)
    # fig.save("test.png")
    return fig
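

# Minimal usage sketch (illustrative only, not executed on import): assuming a model is loaded as
# in the __main__ block below, the returned figure can be rendered in a Streamlit page. Whether
# st.plotly_chart or st.pyplot is the right call depends on what plot_imgs_labels actually returns.
#
#     fig = inference_on_location(model, longitude=2.98, latitude=48.81,
#                                 start_date=2020, end_date=2022, how="year")
#     st.plotly_chart(fig)  # if plot_imgs_labels returns a Plotly figure
#     st.pyplot(fig)        # if it returns a Matplotlib figure
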
@st.cache_data(hash_funcs={LitUnsupervisedSegmenter: lambda dt: dt.name})
def inference_on_location_and_month(model, longitude=2.98, latitude=48.81, start_date="2020-03-20"):
    """Perform an inference at the given latitude and longitude for the month starting at start_date.

    Args:
        model (LitUnsupervisedSegmenter): the model used for the segmentation
        longitude (float): the longitude of the landscape
        latitude (float): the latitude of the landscape
        start_date (str): the start date for the inference, formatted "YYYY-MM-DD"

    Returns:
        fig: a figure showing the original landscape, the labeled landscape and the biodiversity score
    """
logging.info("Running Inference on location and month")
logging.info(f"latitude : {latitude} & longitude : {longitude}")
location = [float(latitude), float(longitude)]
# Extract img numpy from earth engine and transform it to PIL img
end_date = datetime.datetime.strptime(start_date, "%Y-%m-%d") + relativedelta(months=1)
end_date = datetime.datetime.strftime(end_date, "%Y-%m-%d")
img_test = get_image(location, start_date, end_date)
outputs = inference(np.array([img_test]), model)
logging.info("Calculating Biodiversity Score...")
score, score_details = compute_biodiv_score(outputs[0]["linear_preds"].detach().numpy())
logging.info(f"Calculated Biodiversity Score : {score}")
img, label, labeled_img = transform_to_pil(outputs[0])
title=f"Prediction at location {tuple(location)} at {start_date}"
fig = plot_image([start_date], [np.asarray(img)], [np.asarray(labeled_img)], [score_details], [score],title=title)
return fig
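

# Sketch for the single-month variant, under the same assumptions as above: the start_date string
# is expanded to a one-month window internally, so only the window start is passed.
#
#     fig = inference_on_location_and_month(model, longitude=2.98, latitude=48.81,
#                                           start_date="2020-03-20")
#     st.plotly_chart(fig)  # or st.pyplot(fig), depending on what plot_image returns
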
if __name__ == "__main__":
    import logging
    import hydra
    import sys
    from model import LitUnsupervisedSegmenter

    file_handler = logging.FileHandler(filename='biomap.log')
    stdout_handler = logging.StreamHandler(stream=sys.stdout)
    handlers = [file_handler, stdout_handler]
    logging.basicConfig(handlers=handlers, encoding='utf-8', level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")

    # Initialize hydra with configs
    hydra.initialize(config_path="configs", job_name="corine")
    cfg = hydra.compose(config_name="my_train_config.yml")
    logging.info(f"config : {cfg}")

    # Load the model
    nbclasses = cfg.dir_dataset_n_classes
    model = LitUnsupervisedSegmenter(nbclasses, cfg)
logging.info(f"Model Initialiazed")
model_path = "biomap/checkpoint/model/model.pt"
saved_state_dict = torch.load(model_path, map_location=torch.device("cpu"))
logging.info(f"Model weights Loaded")
model.load_state_dict(saved_state_dict)
logging.info(f"Model Loaded")
# inference_on_location_and_month(model)
inference_on_location(model)