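"""Inference wrappers around PHNet (image harmonization) and StyleMatte (matting)."""
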
import numpy as np
import torch
import torchvision.transforms.functional as tf

from .model import PHNet
from .stylematte import StyleMatte
from .util import inference_img, log


class Inference:
    """Runs PHNet to harmonize a composited foreground with its background."""

    def __init__(self, **kwargs):
        self.rank = 0
        self.__dict__.update(kwargs)
        self.model = PHNet(enc_sizes=self.enc_sizes,
                           skips=self.skips,
                           grid_count=self.grid_counts,
                           init_weights=self.init_weights,
                           init_value=self.init_value).to(self.device)
        log(f"checkpoint: {self.checkpoint.harmonizer}")
        state = torch.load(self.checkpoint.harmonizer,
                           map_location=self.device)

        self.model.load_state_dict(state, strict=True)
        self.model.eval()

    def harmonize(self, composite, mask):
        # Bring both inputs up to batched 4D (N, C, H, W) tensors.
        while len(composite.shape) < 4:
            composite = composite.unsqueeze(0)
        while len(mask.shape) < 4:
            mask = mask.unsqueeze(0)
        composite = tf.resize(
            composite, [self.image_size, self.image_size]).to(self.device)
        mask = tf.resize(
            mask, [self.image_size, self.image_size]).to(self.device)
        log(f"composite: {composite.shape}, mask: {mask.shape}")
        with torch.no_grad():
            harmonized = self.model(composite, mask)['harmonized']

        # Keep the harmonized foreground inside the mask, the original elsewhere.
        result = harmonized * mask + composite * (1 - mask)
        log(f"result: {result.shape}")
        return result


class Matting:
    """Runs StyleMatte to predict an alpha matte and cut out the foreground."""

    def __init__(self, **kwargs):
        self.rank = 0
        self.__dict__.update(kwargs)
        self.model = StyleMatte().to(self.device)
        log(f"checkpoint: {self.checkpoint.matting}")
        state = torch.load(self.checkpoint.matting, map_location=self.device)
        self.model.load_state_dict(state, strict=True)
        self.model.eval()

    def extract(self, inp):
        # Predict a single-channel alpha matte, then mask the input with it.
        mask = inference_img(self.model, inp, self.device)
        inp_np = np.array(inp)
        fg = mask[:, :, None] * inp_np

        return [mask, fg]
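

if __name__ == "__main__":
    # Minimal usage sketch, run via `python -m <package>.<module>` because of the
    # relative imports above. Assumptions not taken from this file: an
    # OmegaConf-style config exposing `device`, `image_size`, the PHNet
    # hyperparameters, and attribute-style `checkpoint.harmonizer` /
    # `checkpoint.matting` paths; "config.yaml" and "composite.png" are
    # hypothetical file names.
    from omegaconf import OmegaConf
    from PIL import Image

    cfg = OmegaConf.load("config.yaml")
    matting = Matting(**cfg)
    harmonizer = Inference(**cfg)

    image = Image.open("composite.png").convert("RGB")
    # `inference_img` is assumed here to return the matte as an (H, W) NumPy array.
    mask, _fg = matting.extract(image)
    composite = tf.to_tensor(image)           # (C, H, W), values in [0, 1]
    mask_t = torch.from_numpy(mask).float()   # batched to 4D inside harmonize()
    result = harmonizer.harmonize(composite, mask_t)
    log(f"harmonized: {result.shape}")        # expected (1, C, H, W)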