File size: 2,334 Bytes
5c718d1
 
 
9fcd62f
5c718d1
5dd3935
5c718d1
 
 
 
 
 
 
 
 
4debc65
5dd3935
9fcd62f
5dd3935
5c718d1
 
5dd3935
5c718d1
 
5dd3935
 
 
 
4debc65
 
 
 
5dd3935
5c718d1
 
 
 
 
 
4debc65
5c718d1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4debc65
 
5c718d1
 
 
 
 
 
 
 
 
5dd3935
5c718d1
5dd3935
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import torch.multiprocessing
import torchvision.transforms as T
from utils import transform_to_pil
import logging

# Per-image preprocessing applied before images are stacked into a batch:
# convert to a PIL image, resize to 320x320, convert to a tensor, then
# normalize each channel. The mean/std values are the ones commonly used
# for ImageNet-pretrained backbones.
_NORMALIZE_MEAN = [0.485, 0.456, 0.406]
_NORMALIZE_STD = [0.229, 0.224, 0.225]

preprocess = T.Compose(
    [
        T.ToPILImage(),
        T.Resize((320, 320)),
        T.ToTensor(),
        T.Normalize(mean=_NORMALIZE_MEAN, std=_NORMALIZE_STD),
    ]
)

import numpy as np
def inference(images, model):
    """Run the model on a batch of images and return per-image predictions.

    Every image goes through the module-level ``preprocess`` pipeline, the
    results are stacked into a single CPU batch ``x``, and the batch is fed
    to ``model.net`` and then ``model.linear_probe``. The predicted class is
    the argmax over dimension 1 of the linear-probe output. Predictions equal
    to class 5 are rewritten to class 3 ("water to natural green").

    Args:
        images: Iterable of images accepted by ``preprocess``. May be empty.
        model: Object exposing ``net(x)`` (returning a 2-tuple whose second
            element is the code) and ``linear_probe(x, code)``.

    Returns:
        A list with one dict per input image, in input order, with keys
        ``"img"`` (the preprocessed image tensor) and ``"linear_preds"``
        (the remapped class predictions). An empty input yields ``[]``.
    """
    logging.info("Inference on Images")

    # Materialize first so any iterable works and an empty input can be
    # detected: torch.stack raises on an empty list.
    tensors = [preprocess(image) for image in images]
    if not tensors:
        return []
    x = torch.stack(tensors).cpu()

    # Class ids from the original "water to natural green" remapping.
    water_class, natural_green_class = 5, 3

    # NOTE(review): the model is used in whatever train/eval mode the caller
    # left it in. If it contains dropout or batch-norm layers the caller
    # should call model.eval() before inference -- confirm.
    with torch.no_grad():
        _, code = model.net(x)
        linear_pred = model.linear_probe(x, code).argmax(1)
        # Remap the whole batch in one vectorized call instead of looping
        # over the per-image outputs afterwards.
        linear_pred = linear_pred.masked_fill(
            linear_pred == water_class, natural_green_class
        )

    return [
        {
            "img": img.detach().cpu(),
            "linear_preds": pred.detach().cpu(),
        }
        for img, pred in zip(x, linear_pred)
    ]


if __name__ == "__main__":
    # Manual smoke run: fetch one image, load the checkpoint on CPU and push
    # the image through `inference` as a batch of one and as a batch of two.
    import hydra
    from model import LitUnsupervisedSegmenter
    from utils_gee import extract_img, transform_ee_img
    import os

    # NOTE(review): (2.98, 48.81) would be near Paris if read as
    # (longitude, latitude) -- confirm which order extract_img expects.
    latitude, longitude = 2.98, 48.81
    start_date, end_date = "2020-03-20", "2020-04-20"

    # Extract the image array from Earth Engine and convert it to a PIL image.
    # `max` is the array value that will be mapped to 255.
    ee_img = extract_img([latitude, longitude], start_date, end_date)
    image = transform_ee_img(ee_img, max=0.3)
    print("image loaded")

    # Initialize hydra and compose the training configuration.
    hydra.initialize(config_path="configs", job_name="corine")
    cfg = hydra.compose(config_name="my_train_config.yml")

    # The checkpoint path is resolved relative to this file; weights are
    # mapped onto the CPU.
    here = os.path.dirname(__file__)
    checkpoint = torch.load(
        os.path.join(here, "checkpoint/model/model.pt"),
        map_location=torch.device("cpu"),
    )

    model = LitUnsupervisedSegmenter(cfg.dir_dataset_n_classes, cfg)
    print("model initialized")
    model.load_state_dict(checkpoint)
    print("model loaded")

    # Run once with a single image, then with a batch of two.
    for batch in ([image], [image, image]):
        inference(batch, model)