import face_recognition
import numpy as np
from PIL import Image
import torch
from torchvision import transforms
from torchvision.io import write_video
import tempfile

# grayscale blending masks scaled to [0, 1]: mask1024 weights the full
# 1024px crop, mask512 weights each 512px sliding window
mask_file = torch.from_numpy(np.array(Image.open('assets/mask1024.jpg').convert('L'))) / 255
small_mask_file = torch.from_numpy(np.array(Image.open('assets/mask512.jpg').convert('L'))) / 255

def sliding_window_tensor(input_tensor, window_size, stride, your_model, mask=mask_file, small_mask=small_mask_file):
    """
    Apply aging operation on input tensor using a sliding-window method. This operation is done on the GPU, if available.
    """

    device = next(your_model.parameters()).device
    input_tensor = input_tensor.to(device)
    mask = mask.to(device)
    small_mask = small_mask.to(device)

    n, c, h, w = input_tensor.size()
    output_tensor = torch.zeros((n, 3, h, w), dtype=input_tensor.dtype, device=input_tensor.device)

    count_tensor = torch.zeros((n, 3, h, w), dtype=torch.float32, device=input_tensor.device)

    # extend the loop bound so the window flush with the image edge is included
    # when the stride tiles the image evenly
    add = 2 if window_size % stride != 0 else 1

    for y in range(0, h - window_size + add, stride):
        for x in range(0, w - window_size + add, stride):
            window = input_tensor[:, :, y:y + window_size, x:x + window_size]

            # Forward pass (inference only; torch.autograd.Variable is deprecated,
            # so the window tensor is fed to the model directly under no_grad)
            with torch.no_grad():
                output = your_model(window)

            output_tensor[:, :, y:y + window_size, x:x + window_size] += output * small_mask
            count_tensor[:, :, y:y + window_size, x:x + window_size] += small_mask

    count_tensor = torch.clamp(count_tensor, min=1.0)

    # Average the overlapping regions
    output_tensor /= count_tensor

    # Apply mask
    output_tensor *= mask

    return output_tensor.cpu()
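
# A minimal sanity-check sketch (not part of the original pipeline): with the
# defaults window_size=512 and stride=256, a 1024x1024 crop is covered by a
# 3x3 grid of overlapping windows whose outputs are mask-weighted and averaged.
# The toy 1x1-conv "model" below is a hypothetical stand-in that maps the
# 5 input channels (RGB + two age planes) to 3 output channels:
#
#   import torch.nn as nn
#   toy_model = nn.Conv2d(5, 3, kernel_size=1)
#   dummy = torch.rand(1, 5, 1024, 1024)
#   out = sliding_window_tensor(dummy, window_size=512, stride=256, your_model=toy_model)
#   assert out.shape == (1, 3, 1024, 1024)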


def process_image(your_model, image, video, source_age, target_age=0,
                  window_size=512, stride=256, steps=18):
    """
    Aging the person in the image.
    If video=False, we age as from source_age to target_age, and return an image.
    If video=True, we age from source_age to a range of target ages, and return this as the path to a video.
    """
    if video:
        target_age = 0  # the video sweep starts from a target age of 0
    input_size = (1024, 1024)

    image = np.array(image)
    if video:  # the h264 codec requires frame dimensions to be divisible by 2
        # numpy image arrays are (height, width, channels); crop to even
        # dimensions by slicing (ndarray.resize would scramble the pixel rows)
        height, width, depth = image.shape
        new_height = height if height % 2 == 0 else height - 1
        new_width = width if width % 2 == 0 else width - 1
        image = image[:new_height, :new_width, :]

    # face_locations returns (top, right, bottom, left) boxes; use the first face found
    fl = face_recognition.face_locations(image)[0]

    # calculate crop margins around the face box
    margin_y_t = int((fl[2] - fl[0]) * .63 * .85)  # larger on top, as the forehead is often cut off
    margin_y_b = int((fl[2] - fl[0]) * .37 * .85)
    margin_x = int((fl[1] - fl[3]) // (2 / .85))
    margin_y_t += 2 * margin_x - margin_y_t - margin_y_b  # equalize vertical and horizontal margins to keep the crop square
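    # Worked example (illustrative numbers): for a 200x200 face box,
    # margin_y_t = int(200*.63*.85) = 107, margin_y_b = int(200*.37*.85) = 62,
    # and margin_x = int(200 // (2/.85)) = 85; the adjustment then sets
    # margin_y_t = 2*85 - 62 = 108, so 170 px are added in both dimensions and
    # a square face box yields a square crop (here 370x370).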

    l_y = max([fl[0] - margin_y_t, 0])
    r_y = min([fl[2] + margin_y_b, image.shape[0]])
    l_x = max([fl[3] - margin_x, 0])
    r_x = min([fl[1] + margin_x, image.shape[1]])

    # crop image
    cropped_image = image[l_y:r_y, l_x:r_x, :]

    # Resizing
    orig_size = cropped_image.shape[:2]

    cropped_image = transforms.ToTensor()(cropped_image)

    cropped_image_resized = transforms.Resize(input_size, interpolation=transforms.InterpolationMode.BILINEAR, antialias=True)(cropped_image)

    # build the 5-channel input: RGB plus constant source-age and target-age planes (encoded as age / 100)
    source_age_channel = torch.full_like(cropped_image_resized[:1, :, :], source_age / 100)
    target_age_channel = torch.full_like(cropped_image_resized[:1, :, :], target_age / 100)
    input_tensor = torch.cat([cropped_image_resized, source_age_channel, target_age_channel], dim=0).unsqueeze(0)

    image = transforms.ToTensor()(image)

    if video:
        # aging in steps
        interval = .8 / steps
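        # the target-age channel encodes age / 100 and is bumped before each
        # inference, so the frames sweep from age 80/steps up to age 80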
        aged_cropped_images = torch.zeros((steps, 3, input_size[1], input_size[0]))
        for i in range(0, steps):
            input_tensor[:, -1, :, :] += interval

            # performing actions on image
            aged_cropped_images[i, ...] = sliding_window_tensor(input_tensor, window_size, stride, your_model)

        # resize back to original size
        aged_cropped_images_resized = transforms.Resize(orig_size, interpolation=transforms.InterpolationMode.BILINEAR, antialias=True)(
            aged_cropped_images)

        # re-apply: add the masked aged output back onto the original frames
        image = image.repeat(steps, 1, 1, 1)

        image[:, :, l_y:r_y, l_x:r_x] += aged_cropped_images_resized
        image = torch.clamp(image, 0, 1)
        image = (image * 255).to(torch.uint8)

        # delete=False keeps the file on disk after the handle closes;
        # the caller is responsible for cleaning it up
        output_file = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)

        write_video(output_file.name, image.permute(0, 2, 3, 1), 2)  # 2 fps

        return output_file.name

    else:
        # performing actions on image
        aged_cropped_image = sliding_window_tensor(input_tensor, window_size, stride, your_model)

        # resize back to original size
        aged_cropped_image_resized = transforms.Resize(orig_size, interpolation=transforms.InterpolationMode.BILINEAR, antialias=True)(
            aged_cropped_image)

        # re-apply: add the masked aged output back onto the original image
        image[:, l_y:r_y, l_x:r_x] += aged_cropped_image_resized.squeeze(0)
        image = torch.clamp(image, 0, 1)

        return transforms.functional.to_pil_image(image)
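

# A hedged usage sketch rather than a definitive entry point: the checkpoint
# path and input image below are hypothetical placeholders. Any nn.Module that
# maps the 5-channel input (RGB + source-age + target-age planes) to a
# 3-channel RGB output can be plugged in.
if __name__ == "__main__":
    your_model = torch.load("checkpoints/aging_model.pt", map_location="cpu")  # hypothetical path
    your_model.eval()

    img = Image.open("assets/example.jpg").convert("RGB")  # hypothetical image

    # single image: age the subject from 25 to 65
    aged = process_image(your_model, img, video=False, source_age=25, target_age=65)
    aged.save("aged_to_65.jpg")

    # video: sweep a range of target ages and return the path to an .mp4
    video_path = process_image(your_model, img, video=True, source_age=25)
    print(f"aging video written to {video_path}")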