Spaces:
Running
Running
File size: 1,949 Bytes
f0de4e8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 |
import torch
from .model import PHNet
import torchvision.transforms.functional as tf
from .util import inference_img, log
from .stylematte import StyleMatte
import numpy as np
class Inference:
    """Image-harmonization wrapper around a pretrained ``PHNet`` model.

    Every keyword argument is copied onto the instance as an attribute. The
    code below reads:

    * ``enc_sizes``, ``skips``, ``grid_counts``, ``init_weights``,
      ``init_value`` -- forwarded to the ``PHNet`` constructor;
    * ``checkpoint`` -- an object whose ``harmonizer`` attribute is the path
      handed to ``torch.load``;
    * ``device`` -- where the weights and the network live;
    * ``image_size`` -- side length inputs are resized to in ``harmonize``.
    """

    def __init__(self, **kwargs):
        self.rank = 0
        self.__dict__.update(kwargs)
        # load_state_dict copies values into the existing parameters, so the
        # module itself must be moved to ``self.device``; otherwise it stays on
        # the CPU no matter what ``map_location`` says (Matting does the same).
        self.model = PHNet(enc_sizes=self.enc_sizes,
                           skips=self.skips,
                           grid_count=self.grid_counts,
                           init_weights=self.init_weights,
                           init_value=self.init_value).to(self.device)
        log(f"checkpoint: {self.checkpoint.harmonizer}")
        # NOTE(review): torch.load unpickles the file; only point this at
        # trusted checkpoints.
        state = torch.load(self.checkpoint.harmonizer,
                           map_location=self.device)
        self.model.load_state_dict(state, strict=True)
        self.model.eval()

    def harmonize(self, composite, mask):
        """Harmonize ``composite`` inside the region selected by ``mask``.

        Args:
            composite: image tensor. A 3-D tensor gets a leading batch axis
                added (presumably (C, H, W) -> (1, C, H, W); TODO confirm).
            mask: tensor that is unsqueezed at the front until it is 4-D.

        Returns:
            Tensor resized to ``image_size`` x ``image_size``, where the
            network output is used with weight ``mask`` and the resized
            composite with weight ``1 - mask``. It is placed on the same
            device ``composite`` was passed on.
        """
        # Hand the result back on the caller's device even though the
        # computation runs on ``self.device``.
        out_device = composite.device
        if len(composite.shape) < 4:
            composite = composite.unsqueeze(0)
        while len(mask.shape) < 4:
            mask = mask.unsqueeze(0)
        composite = composite.to(self.device)
        mask = mask.to(self.device)
        size = [self.image_size, self.image_size]
        composite = tf.resize(composite, size)
        mask = tf.resize(mask, size)
        log(composite.shape, mask.shape)
        with torch.no_grad():
            harmonized = self.model(composite, mask)['harmonized']
        # Alpha-blend: network output where mask is 1, original composite
        # where mask is 0.
        result = harmonized * mask + composite * (1 - mask)
        # Route diagnostics through ``log`` instead of a stray print().
        log(result.shape)
        return result.to(out_device)
class Matting:
    """Foreground-extraction wrapper around a pretrained ``StyleMatte`` model.

    Keyword arguments are stored as instance attributes. ``device`` selects
    where the model and its weights are placed, and ``checkpoint.matting``
    is the weights path given to ``torch.load``.
    """

    def __init__(self, **kwargs):
        self.rank = 0
        self.__dict__.update(kwargs)
        self.model = StyleMatte().to(self.device)
        weights_path = self.checkpoint.matting
        log(f"checkpoint: {weights_path}")
        self.model.load_state_dict(
            torch.load(weights_path, map_location=self.device),
            strict=True,
        )
        self.model.eval()

    def extract(self, inp):
        """Run matting on ``inp`` and return ``[alpha, foreground]``.

        ``alpha`` is whatever ``inference_img`` produces for the image. It
        is indexed as 2-D here (assumed (H, W); TODO confirm). ``foreground``
        is ``np.array(inp)`` weighted per pixel by ``alpha``, broadcast over a
        new trailing axis.
        """
        alpha = inference_img(self.model, inp, self.device)
        pixels = np.array(inp)
        foreground = alpha[:, :, None] * pixels
        return [alpha, foreground]
|