timroelofs123 committed
Commit 1931503 · 1 Parent(s): 3c9f70d

scripts test

Files changed (2)
  1. scripts/__init__.py +0 -0
  2. scripts/test_functions.py +141 -0
scripts/__init__.py ADDED
File without changes
scripts/test_functions.py ADDED
@@ -0,0 +1,141 @@
+ import face_recognition
+ import numpy as np
+ from PIL import Image
+ import torch
+ from torchvision import transforms
+ from torchvision.io import write_video
+ import tempfile
+
+ # Grayscale blending masks, scaled to [0, 1]: mask1024 blends the full face
+ # crop, mask512 blends each sliding window.
+ mask_file = torch.from_numpy(np.array(Image.open('assets/mask1024.jpg').convert('L'))) / 255
+ small_mask_file = torch.from_numpy(np.array(Image.open('assets/mask512.jpg').convert('L'))) / 255
+
+ def sliding_window_tensor(input_tensor, window_size, stride, your_model, mask=mask_file, small_mask=small_mask_file):
+     """
+     Apply the aging operation to the input tensor using a sliding-window method.
+     This runs on the GPU, if available.
+     """
+     device = next(your_model.parameters()).device
+     input_tensor = input_tensor.to(device)
+     mask = mask.to(device)
+     small_mask = small_mask.to(device)
+
+     n, c, h, w = input_tensor.size()
+     output_tensor = torch.zeros((n, 3, h, w), dtype=input_tensor.dtype, device=device)
+     count_tensor = torch.zeros((n, 3, h, w), dtype=torch.float32, device=device)
+
+     # Extend the range bound so the final window touching the edge is included.
+     add = 2 if window_size % stride != 0 else 1
+
+     for y in range(0, h - window_size + add, stride):
+         for x in range(0, w - window_size + add, stride):
+             window = input_tensor[:, :, y:y + window_size, x:x + window_size]
+
+             # Forward pass; no gradients are needed at inference time.
+             with torch.no_grad():
+                 output = your_model(window)
+
+             # Accumulate the masked window output and count each pixel's
+             # contributions so overlapping regions can be averaged afterwards.
+             output_tensor[:, :, y:y + window_size, x:x + window_size] += output * small_mask
+             count_tensor[:, :, y:y + window_size, x:x + window_size] += small_mask
+
+     # Average the overlapping regions; clamping avoids division by zero
+     # where no window contributed.
+     count_tensor = torch.clamp(count_tensor, min=1.0)
+     output_tensor /= count_tensor
+
+     # Blend with the full-crop mask.
+     output_tensor *= mask
+
+     return output_tensor.cpu()
+
+
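A minimal usage sketch for sliding_window_tensor on its own (not part of this commit; the 1x1-conv stand-in model and the random input are assumptions, and the two mask files under assets/ must exist for the module to import):

import torch
import torch.nn as nn
from scripts.test_functions import sliding_window_tensor

# Placeholder for the real aging model: any module mapping
# (N, 5, 512, 512) windows to (N, 3, 512, 512) outputs fits here.
model = nn.Conv2d(5, 3, kernel_size=1)

# Five channels: RGB plus the source-age and target-age conditioning planes.
dummy = torch.rand(1, 5, 1024, 1024)
out = sliding_window_tensor(dummy, window_size=512, stride=256, your_model=model)
print(out.shape)  # torch.Size([1, 3, 1024, 1024])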
+ def process_image(your_model, image, video, source_age, target_age=0,
+                   window_size=512, stride=256, steps=18):
+     """
+     Age the person in the image.
+     If video=False, we age from source_age to target_age and return an image.
+     If video=True, we age from source_age through a range of target ages and
+     return the path to a video.
+     """
+     if video:
+         target_age = 0
+     input_size = (1024, 1024)
+
+     image = np.array(image)
+     if video:  # the h264 codec requires frame sizes divisible by 2
+         height, width, depth = image.shape
+         new_height = height if height % 2 == 0 else height - 1
+         new_width = width if width % 2 == 0 else width - 1
+         # Crop to even dimensions by slicing; in-place ndarray.resize would
+         # truncate the flat buffer and skew the image.
+         image = image[:new_height, :new_width, :]
+
+     # (top, right, bottom, left) of the first detected face
+     fl = face_recognition.face_locations(image)[0]
+
+     # calculate margins around the detected face box
+     margin_y_t = int((fl[2] - fl[0]) * .63 * .85)  # larger, as the forehead is often cut off
+     margin_y_b = int((fl[2] - fl[0]) * .37 * .85)
+     margin_x = int((fl[1] - fl[3]) // (2 / .85))
+     margin_y_t = 2 * margin_x - margin_y_b  # make sure the crop stays square
+
+     l_y = max(fl[0] - margin_y_t, 0)
+     r_y = min(fl[2] + margin_y_b, image.shape[0])
+     l_x = max(fl[3] - margin_x, 0)
+     r_x = min(fl[1] + margin_x, image.shape[1])
+
+     # crop the face region and remember its original size
+     cropped_image = image[l_y:r_y, l_x:r_x, :]
+     orig_size = cropped_image.shape[:2]
+
+     cropped_image = transforms.ToTensor()(cropped_image)
+     cropped_image_resized = transforms.Resize(input_size, interpolation=transforms.InterpolationMode.BILINEAR, antialias=True)(cropped_image)
+
+     # Append two constant conditioning channels holding the source and
+     # target ages, normalized to [0, 1].
+     source_age_channel = torch.full_like(cropped_image_resized[:1, :, :], source_age / 100)
+     target_age_channel = torch.full_like(cropped_image_resized[:1, :, :], target_age / 100)
+     input_tensor = torch.cat([cropped_image_resized, source_age_channel, target_age_channel], dim=0).unsqueeze(0)
+
+     image = transforms.ToTensor()(image)
+
+     if video:
+         # Age in steps: the target-age channel ramps from 0 towards 0.8
+         # (i.e. age 80) over `steps` frames.
+         interval = .8 / steps
+         aged_cropped_images = torch.zeros((steps, 3, input_size[1], input_size[0]))
+         for i in range(steps):
+             input_tensor[:, -1, :, :] += interval
+             aged_cropped_images[i, ...] = sliding_window_tensor(input_tensor, window_size, stride, your_model)
+
+         # resize back to the original crop size
+         aged_cropped_images_resized = transforms.Resize(orig_size, interpolation=transforms.InterpolationMode.BILINEAR, antialias=True)(aged_cropped_images)
+
+         # re-apply: add the masked aging delta onto each frame of the original image
+         image = image.repeat(steps, 1, 1, 1)
+         image[:, :, l_y:r_y, l_x:r_x] += aged_cropped_images_resized
+         image = torch.clamp(image, 0, 1)
+         image = (image * 255).to(torch.uint8)
+
+         output_file = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
+         write_video(output_file.name, image.permute(0, 2, 3, 1), fps=2)
+
+         return output_file.name
+
+     else:
+         aged_cropped_image = sliding_window_tensor(input_tensor, window_size, stride, your_model)
+
+         # resize back to the original crop size
+         aged_cropped_image_resized = transforms.Resize(orig_size, interpolation=transforms.InterpolationMode.BILINEAR, antialias=True)(aged_cropped_image)
+
+         # re-apply: add the masked aging delta onto the original image
+         image[:, l_y:r_y, l_x:r_x] += aged_cropped_image_resized.squeeze(0)
+         image = torch.clamp(image, 0, 1)
+
+         return transforms.functional.to_pil_image(image)
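A hedged end-to-end sketch of how process_image might be driven (not part of this commit; the stand-in model, the photo path, and the output names are assumptions - the real project presumably loads a trained aging generator, and the input must contain a face that face_recognition can detect):

import torch.nn as nn
from PIL import Image
from scripts.test_functions import process_image

model = nn.Conv2d(5, 3, kernel_size=1)  # placeholder; substitute the trained aging model
photo = Image.open('my_portrait.jpg')   # hypothetical input portrait

# Image mode: returns a PIL image aged from 25 to 65.
aged = process_image(model, photo, video=False, source_age=25, target_age=65)
aged.save('aged.jpg')

# Video mode: returns the path to a temporary .mp4 sweeping over target ages.
video_path = process_image(model, photo, video=True, source_age=25)
print(video_path)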