import base64
import io
import json
from typing import Any

import torch
from huggingface_hub import hf_hub_download
from omegaconf import OmegaConf
from PIL import Image
from torchvision.transforms import v2 as transforms  # ToImage/ToDtype live in the v2 API

from model import Generator


class EndpointHandler:

    def __init__(self, path: str = ""):
        # Preprocessing: PIL image -> float tensor in [0, 1] -> normalized to [-1, 1].
        self.transform = transforms.Compose(
            [
                transforms.ToImage(),
                transforms.ToDtype(torch.float32, scale=True),
                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
            ]
        )

        repo_id = "Kiwinicki/sat2map-generator"
        generator_path = hf_hub_download(repo_id=repo_id, filename="generator.pth")
        config_path = hf_hub_download(repo_id=repo_id, filename="config.json")
        # model.py (the Generator definition) ships in the same repo and is
        # imported at module load, so it must sit next to this handler.

        with open(config_path, "r") as f:
            config_dict = json.load(f)
        cfg = OmegaConf.create(config_dict)

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.generator = Generator(cfg)
        state_dict = torch.load(generator_path, map_location=self.device, weights_only=True)
        self.generator.load_state_dict(state_dict)
        self.generator.to(self.device)  # weights must live on the same device as the inputs
        self.generator.eval()


    def __call__(self, data: dict[str, Any]) -> dict[str, str]:
        base64_image = data.get("inputs")
        input_tensor = self._decode_base64_image(base64_image)
        with torch.inference_mode():
            output_tensor = self.generator(input_tensor.to(self.device))
        # Undo the input normalization (this assumes the generator emits values
        # in [-1, 1], mirroring the Normalize(mean=0.5, std=0.5) preprocessing).
        output_tensor = (output_tensor.squeeze(0).cpu() * 0.5 + 0.5).clamp(0, 1)
        output_image = transforms.ToPILImage()(output_tensor).convert("RGB")
        output_buffer = io.BytesIO()
        output_image.save(output_buffer, format="PNG")
        base64_output = base64.b64encode(output_buffer.getvalue()).decode("utf-8")
        return {"output": base64_output}


    def _decode_base64_image(self, base64_image: str) -> torch.Tensor:
        image_decoded = base64.b64decode(base64_image)
        image = Image.open(io.BytesIO(image_decoded)).convert("RGB")
        image_tensor: torch.Tensor = self.transform(image)
        # Add a batch dimension: (C, H, W) -> (1, C, H, W).
        return image_tensor.unsqueeze(0)
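

if __name__ == "__main__":
    # Quick local smoke test (not part of the Inference Endpoints contract):
    # round-trips a synthetic image through the handler. The 256x256 size is
    # an assumption for illustration; use a real satellite tile in practice.
    handler = EndpointHandler()

    dummy = Image.new("RGB", (256, 256), color=(90, 120, 80))
    buffer = io.BytesIO()
    dummy.save(buffer, format="PNG")
    payload = {"inputs": base64.b64encode(buffer.getvalue()).decode("utf-8")}

    result = handler(payload)
    map_image = Image.open(io.BytesIO(base64.b64decode(result["output"])))
    print(f"Generated map image: size={map_image.size}, mode={map_image.mode}")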