svjack commited on
Commit
25ebfd3
·
verified ·
1 Parent(s): 8959c4b

Upload video_to_sketch_script_cv2.py

Browse files
Files changed (1) hide show
  1. video_to_sketch_script_cv2.py +122 -0
video_to_sketch_script_cv2.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ python video_to_sketch_script_cv2.py test test-sketch --copy_others
3
+ '''
4
+
5
+ import gc
6
+ import os
7
+ import shutil
8
+ import argparse
9
+ import numpy as np
10
+ import torch
11
+ import cv2
12
+ from huggingface_hub import hf_hub_download
13
+ from PIL.Image import Resampling
14
+ from pytorchvideo.data.encoded_video import EncodedVideo
15
+ from pytorchvideo.transforms.functional import uniform_temporal_subsample
16
+ from torchvision.transforms.functional import resize
17
+ from tqdm import tqdm
18
+
19
+ from modeling import Generator
20
+
21
+ MAX_DURATION = 60
22
+ OUT_FPS = 18
23
+ DEVICE = "cpu" if not torch.cuda.is_available() else "cuda"
24
+
25
+ # Load the model
26
+ model = Generator(3, 1, 3)
27
+ weights_path = hf_hub_download("nateraw/image-2-line-drawing", "pytorch_model.bin")
28
+ model.load_state_dict(torch.load(weights_path, map_location=DEVICE))
29
+ model.eval()
30
+
31
+ def process_one_second(vid, start_sec, out_fps):
32
+ """Process one second of a video at a given fps
33
+ Args:
34
+ vid (_type_): A pytorchvideo.EncodedVideo instance containing the video to process
35
+ start_sec (_type_): The second to start processing at
36
+ out_fps (_type_): The fps to output the video at
37
+ Returns:
38
+ np.array: The processed video as a numpy array with shape (T, H, W, C)
39
+ """
40
+ # C, T, H, W
41
+ video_arr = vid.get_clip(start_sec, start_sec + 1)["video"]
42
+ # C, T, H, W where T == frames per second
43
+ x = uniform_temporal_subsample(video_arr, out_fps)
44
+ # C, T, H, W where H has been scaled to 256
45
+ x = resize(x, 256, Resampling.BICUBIC)
46
+ # C, T, H, W -> T, C, H, W (basically T acts as batch size now)
47
+ x = x.permute(1, 0, 2, 3)
48
+
49
+ with torch.no_grad():
50
+ # T, 1, H, W
51
+ out = model(x)
52
+
53
+ # T, C, H, W -> T, H, W, C Rescaled to 0-255
54
+ out = out.permute(0, 2, 3, 1).clip(0, 1) * 255
55
+ # Greyscale -> RGB
56
+ out = out.repeat(1, 1, 1, 3)
57
+ return out.cpu().numpy().astype(np.uint8) # Convert to uint8 for OpenCV
58
+
59
+ def process_video(input_video_path, output_video_path):
60
+ vid = EncodedVideo.from_path(input_video_path)
61
+ duration = min(MAX_DURATION, int(vid.duration))
62
+
63
+ # Initialize VideoWriter (get frame size from first frame)
64
+ first_frame = process_one_second(vid, start_sec=0, out_fps=OUT_FPS)
65
+ height, width = first_frame.shape[1], first_frame.shape[2]
66
+
67
+ # Use 'mp4v' for .mp4, 'XVID' for .avi, etc.
68
+ fourcc = cv2.VideoWriter_fourcc(*'mp4v')
69
+ out_video = cv2.VideoWriter(
70
+ output_video_path,
71
+ fourcc,
72
+ OUT_FPS,
73
+ (width, height)
74
+ )
75
+
76
+ # Write first frame
77
+ for frame in first_frame:
78
+ frame_bgr = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
79
+ out_video.write(frame_bgr)
80
+
81
+ # Process remaining frames
82
+ for i in tqdm(range(1, duration), desc="Processing video"):
83
+ video_frames = process_one_second(vid, start_sec=i, out_fps=OUT_FPS)
84
+ for frame in video_frames:
85
+ frame_bgr = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
86
+ out_video.write(frame_bgr)
87
+ gc.collect()
88
+
89
+ out_video.release()
90
+
91
+ def copy_non_video_files(input_path, output_path):
92
+ """Copy non-video files and directories from input path to output path."""
93
+ for item in os.listdir(input_path):
94
+ item_path = os.path.join(input_path, item)
95
+ if not item.lower().endswith(('.mp4', '.avi', '.mov', '.mkv')):
96
+ dest_path = os.path.join(output_path, item)
97
+ if os.path.isdir(item_path):
98
+ shutil.copytree(item_path, dest_path)
99
+ else:
100
+ shutil.copy2(item_path, dest_path)
101
+
102
+ def main(input_path, output_path, copy_others=False):
103
+ if not os.path.exists(output_path):
104
+ os.makedirs(output_path)
105
+
106
+ if copy_others:
107
+ copy_non_video_files(input_path, output_path)
108
+
109
+ for video_name in os.listdir(input_path):
110
+ if video_name.lower().endswith(('.mp4', '.avi', '.mov', '.mkv')):
111
+ input_video_path = os.path.join(input_path, video_name)
112
+ output_video_path = os.path.join(output_path, video_name)
113
+ process_video(input_video_path, output_video_path)
114
+
115
+ if __name__ == "__main__":
116
+ parser = argparse.ArgumentParser(description="Process videos to convert them into sketch videos.")
117
+ parser.add_argument("input_path", type=str, help="Path to the input directory containing videos.")
118
+ parser.add_argument("output_path", type=str, help="Path to the output directory for processed videos.")
119
+ parser.add_argument("--copy_others", action="store_true", help="Copy non-video files and directories from input to output.")
120
+
121
+ args = parser.parse_args()
122
+ main(args.input_path, args.output_path, args.copy_others)