YinuoGuo27 committed
Commit 7ff2c71 · verified · 1 Parent(s): 6635078

Upload croper.py

Files changed (1)
  1. difpoint/src/croper.py +299 -0
difpoint/src/croper.py ADDED
@@ -0,0 +1,299 @@
+ import os
+ import cv2
+ import time
+ import glob
+ import argparse
+ import scipy
+ import numpy as np
+ from PIL import Image
+ from tqdm import tqdm
+ from itertools import cycle
+
+ from torch.multiprocessing import Pool, Process, set_start_method
+
+
+ """
+ brief: face alignment with the FFHQ method (https://github.com/NVlabs/ffhq-dataset)
+ author: lzhbrian (https://lzhbrian.me)
+ date: 2020.1.5
+ note: code is heavily borrowed from
+     https://github.com/NVlabs/ffhq-dataset
+     http://dlib.net/face_landmark_detection.py.html
+ requirements:
+     apt install cmake
+     conda install Pillow numpy scipy
+     pip install dlib
+     # download the face landmark model from:
+     # http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2
+ """
+
+ import dlib
+
+
+ class Croper:
+     def __init__(self, path_of_lm):
+         # download model from: http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2
+         self.predictor = dlib.shape_predictor(path_of_lm)
+         self.detector = dlib.get_frontal_face_detector()
+
+     def get_landmark(self, img_np):
+         """Detect the first face with dlib and return its 68 landmarks.
+         :return: np.array of shape (68, 2), or None if no face is found
+         """
+         dets = self.detector(img_np, 1)
+         if len(dets) == 0:
+             return None
+         d = dets[0]
+         # Get the landmarks/parts for the face in box d.
+         shape = self.predictor(img_np, d)
+         lm = np.array([[p.x, p.y] for p in shape.parts()])
+         return lm
+
+     def align_face(self, img, lm, output_size=1024):
+         """
+         :param img: PIL Image
+         :param lm: np.array of shape (68, 2), dlib landmarks
+         :return: (rsize, crop, quad) -- resized image size, crop box, and quad box
+         """
+         lm_chin = lm[0:17]  # left-right
+         lm_eyebrow_left = lm[17:22]  # left-right
+         lm_eyebrow_right = lm[22:27]  # left-right
+         lm_nose = lm[27:31]  # top-down
+         lm_nostrils = lm[31:36]  # top-down
+         lm_eye_left = lm[36:42]  # left-clockwise
+         lm_eye_right = lm[42:48]  # left-clockwise
+         lm_mouth_outer = lm[48:60]  # left-clockwise
+         lm_mouth_inner = lm[60:68]  # left-clockwise
+
+         # Calculate auxiliary vectors.
+         eye_left = np.mean(lm_eye_left, axis=0)
+         eye_right = np.mean(lm_eye_right, axis=0)
+         eye_avg = (eye_left + eye_right) * 0.5
+         eye_to_eye = eye_right - eye_left
+         mouth_left = lm_mouth_outer[0]
+         mouth_right = lm_mouth_outer[6]
+         mouth_avg = (mouth_left + mouth_right) * 0.5
+         eye_to_mouth = mouth_avg - eye_avg
+
+         # Choose oriented crop rectangle.
+         x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1]  # combine the eye axis with the rotated eye-to-mouth vector
+         x /= np.hypot(*x)  # np.hypot gives the vector length; normalize x to unit length
+         x *= max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * 1.8)  # base scale: the larger of the eye-to-eye and eye-to-mouth distances
+         y = np.flipud(x) * [-1, 1]
+         c = eye_avg + eye_to_mouth * 0.1
+         quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])  # quad corners: the face center c shifted by +/-x and +/-y
+         qsize = np.hypot(*x) * 2  # quad side length: twice the base scale
+
+         # Shrink: if the quad is much larger than the output size, downscale proportionally.
+         shrink = int(np.floor(qsize / output_size * 0.5))
+         if shrink > 1:
+             rsize = (int(np.rint(float(img.size[0]) / shrink)), int(np.rint(float(img.size[1]) / shrink)))
+             img = img.resize(rsize, Image.LANCZOS)  # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is its equivalent
+             quad /= shrink
+             qsize /= shrink
+         else:
+             rsize = (int(np.rint(float(img.size[0]))), int(np.rint(float(img.size[1]))))
+
+         # Crop.
+         border = max(int(np.rint(qsize * 0.1)), 3)
+         crop = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))),
+                 int(np.ceil(max(quad[:, 0]))), int(np.ceil(max(quad[:, 1]))))
+         crop = (max(crop[0] - border, 0), max(crop[1] - border, 0),
+                 min(crop[2] + border, img.size[0]), min(crop[3] + border, img.size[1]))
+         if crop[2] - crop[0] < img.size[0] or crop[3] - crop[1] < img.size[1]:
+             # the pixel crop itself is deferred to the caller; only shift the quad here
+             quad -= crop[0:2]
+
+         # Pad (computed for reference; the FFHQ reflection-padding and blur step is disabled here).
+         pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))),
+                int(np.ceil(max(quad[:, 0]))), int(np.ceil(max(quad[:, 1]))))
+         pad = (max(-pad[0] + border, 0), max(-pad[1] + border, 0),
+                max(pad[2] - img.size[0] + border, 0), max(pad[3] - img.size[1] + border, 0))
+
+         # Transform: clamp the quad to the image bounds.
+         quad = (quad + 0.5).flatten()
+         lx = max(min(quad[0], quad[2]), 0)
+         ly = max(min(quad[1], quad[7]), 0)
+         rx = min(max(quad[4], quad[6]), img.size[0])
+         ry = min(max(quad[3], quad[5]), img.size[1])  # clamp y to the image height (the original clamped to the width)
+
+         # Return the resized image size, crop box, and quad box.
+         return rsize, crop, [lx, ly, rx, ry]
+
+     def crop(self, img_np_list, still=False, xsize=512):  # landmarks from the first frame drive the crop for all frames
+         img_np = img_np_list[0]
+         lm = self.get_landmark(img_np)
+         if lm is None:
+             raise ValueError('cannot detect a face landmark in the source image')
+         rsize, crop, quad = self.align_face(img=Image.fromarray(img_np), lm=lm, output_size=xsize)
+         clx, cly, crx, cry = crop
+         lx, ly, rx, ry = quad
+         lx, ly, rx, ry = int(lx), int(ly), int(rx), int(ry)
+         for _i in range(len(img_np_list)):
+             _inp = img_np_list[_i]
+             _inp = cv2.resize(_inp, (rsize[0], rsize[1]))
+             _inp = _inp[cly:cry, clx:crx]
+             if not still:
+                 _inp = _inp[ly:ry, lx:rx]
+             img_np_list[_i] = _inp
+         return img_np_list, crop, quad
+
+
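+ # Usage sketch (illustrative; the model path and the video filename below are placeholders):
+ #     croper = Croper('shape_predictor_68_face_landmarks.dat')
+ #     frames = read_video('talking_head.mp4', uplimit=100)
+ #     cropped, crop_box, quad_box = croper.crop(frames, still=False, xsize=512)
+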
+ def read_video(filename, uplimit=100):
+     """Read up to `uplimit` frames from a video, each resized to 512x512."""
+     frames = []
+     cap = cv2.VideoCapture(filename)
+     cnt = 0
+     while cap.isOpened():
+         ret, frame = cap.read()
+         if not ret:
+             break
+         frame = cv2.resize(frame, (512, 512))
+         frames.append(frame)
+         cnt += 1
+         if cnt >= uplimit:
+             break
+     cap.release()
+     assert len(frames) > 0, f'{filename}: video with no frames!'
+     return frames
+
+
+ def create_video(video_name, frames, fps=25, video_format='.mp4', resize_ratio=1):
+     height, width, layers = 512, 512, 3
+     if video_format == '.mp4':
+         fourcc = cv2.VideoWriter_fourcc(*'mp4v')
+     elif video_format == '.avi':
+         fourcc = cv2.VideoWriter_fourcc(*'XVID')
+     else:
+         raise ValueError(f'unsupported video format: {video_format}')
+     video = cv2.VideoWriter(video_name, fourcc, fps, (width, height))
+     for _frame in frames:
+         _frame = cv2.resize(_frame, (width, height), interpolation=cv2.INTER_LINEAR)  # cv2.resize takes (width, height)
+         video.write(_frame)
+     video.release()  # flush and close the writer
+
+
+ def create_images(video_name, frames):
+     height, width, layers = 512, 512, 3
+     images_dir = os.path.splitext(video_name)[0]  # strip the extension; split('.')[0] would break on dotted paths
+     os.makedirs(images_dir, exist_ok=True)
+     for i, _frame in enumerate(frames):
+         _frame = cv2.resize(_frame, (width, height), interpolation=cv2.INTER_LINEAR)
+         _frame_path = os.path.join(images_dir, str(i) + '.jpg')
+         cv2.imwrite(_frame_path, _frame)
+
+
+ def run(data):
+     filename, opt, device = data
+     os.environ['CUDA_VISIBLE_DEVICES'] = device
+     croper = Croper('shape_predictor_68_face_landmarks.dat')  # assumed model path (see docstring); adjust to your local copy
+
+     frames = read_video(filename, uplimit=opt.uplimit)
+     name = os.path.basename(filename)
+     name = os.path.join(opt.output_dir, name)
+
+     try:
+         frames, _, _ = croper.crop(frames)
+     except ValueError:
+         print(f'{name}: no face detected; this file should be removed')
+         return
+     # create_video(name, frames)
+     create_images(name, frames)
+
+
+ def get_data_path(video_dir):
+     # placeholder helper: returns one hard-coded example video instead of scanning video_dir
+     eg_video_files = ['/apdcephfs/share_1290939/quincheng/datasets/HDTF/backup_fps25/WDA_KatieHill_000.mp4']
+     return eg_video_files
+
+
+ def get_wra_data_path(video_dir, option='video'):
+     if option == 'video':
+         videos_path = sorted(glob.glob(f'{video_dir}/*.mp4'))
+     elif option == 'image':
+         videos_path = sorted(glob.glob(f'{video_dir}/*/'))
+     else:
+         raise NotImplementedError(f'unknown option: {option}')
+     print('Example videos: ', videos_path[:2])
+     return videos_path
+
+
+ if __name__ == '__main__':
+     set_start_method('spawn')
+     parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+     parser.add_argument('--input_dir', type=str, help='the folder of the input files')
+     parser.add_argument('--output_dir', type=str, help='the folder of the output files')
+     parser.add_argument('--device_ids', type=str, default='0,1')
+     parser.add_argument('--workers', type=int, default=8)
+     parser.add_argument('--uplimit', type=int, default=500)
+     parser.add_argument('--option', type=str, default='video')
+
+     # NOTE: arguments are hard-coded below rather than taken from the command line
+     root = '/apdcephfs/share_1290939/quincheng/datasets/HDTF'
+     cmd = f'--input_dir {root}/backup_fps25_first20s_sync/ ' \
+           f'--output_dir {root}/crop512_stylegan_firstframe_sync/ ' \
+           '--device_ids 0 ' \
+           '--workers 8 ' \
+           '--option video ' \
+           '--uplimit 500 '
+     opt = parser.parse_args(cmd.split())
+     # filenames = get_data_path(opt.input_dir)
+     filenames = get_wra_data_path(opt.input_dir, opt.option)
+     os.makedirs(opt.output_dir, exist_ok=True)
+     print(f'Number of videos: {len(filenames)}')
+     pool = Pool(opt.workers)
+     args_list = cycle([opt])
+     device_ids = cycle(opt.device_ids.split(','))
+     for _ in tqdm(pool.imap_unordered(run, zip(filenames, args_list, device_ids))):
+         pass
+     pool.close()
+     pool.join()