Fabrice-TIERCELIN committed on
Commit
ddbd86c
·
verified ·
1 Parent(s): 6b5cb9f

Update video_super_resolution/scripts/inference_sr.py

Browse files
video_super_resolution/scripts/inference_sr.py CHANGED
@@ -1,56 +1,142 @@
1
#!/bin/bash
# Batch launcher: runs inference_sr.py once per .mp4 file found under the
# video folder, pairing the i-th video with the i-th non-empty prompt line.

# Abort on the first failing command so that a failed inference run is not
# followed by the final "processed successfully" message.
set -eo pipefail

# Folder paths
video_folder_path='./input/video'
txt_file_path='./input/text/prompt.txt'

# Get all .mp4 files in the folder. `find` returns entries in an unspecified
# order, so sort them: videos are matched to prompt lines by index and the
# pairing must be reproducible from run to run.
mapfile -t mp4_files < <(find "$video_folder_path" -type f -name "*.mp4" | LC_ALL=C sort)

# Print the list of MP4 files
echo "MP4 files to be processed:"
for mp4_file in "${mp4_files[@]}"; do
    echo "$mp4_file"
done

# Read lines from the text file, skipping empty lines.
# [[:space:]] is the POSIX spelling; '\s' is only understood by GNU grep.
mapfile -t lines < <(grep -v '^[[:space:]]*$' "$txt_file_path")

# Maximum number of frames per chunk, forwarded as --max_chunk_len
frame_length=32

# Debugging output
echo "Number of MP4 files: ${#mp4_files[@]}"
echo "Number of lines in the text file: ${#lines[@]}"

# Ensure the number of video files matches the number of lines
if [ "${#mp4_files[@]}" -ne "${#lines[@]}" ]; then
    echo "Number of MP4 files and lines in the text file do not match." >&2
    exit 1
fi

# Loop through video files and corresponding lines
for i in "${!mp4_files[@]}"; do
    mp4_file="${mp4_files[$i]}"
    line="${lines[$i]}"

    # Extract the filename without the extension
    file_name=$(basename "$mp4_file" .mp4)

    echo "Processing video file: $mp4_file with prompt: $line"

    # Run Python script with parameters
    python \
        ./video_super_resolution/scripts/inference_sr.py \
        --solver_mode 'fast' \
        --steps 15 \
        --input_path "${mp4_file}" \
        --model_path /mnt/bn/videodataset/VSR/pretrained_models/STAR/heavy_deg.pt \
        --prompt "${line}" \
        --upscale 4 \
        --max_chunk_len "${frame_length}" \
        --file_name "${file_name}.mp4" \
        --save_dir ./results
done

echo "All videos processed successfully."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ from argparse import ArgumentParser, Namespace
4
+ import json
5
+ from typing import Any, Dict, List, Mapping, Tuple
6
+ from easydict import EasyDict
7
 
8
+ from video_to_video.video_to_video_model import VideoToVideo_sr
9
+ from video_to_video.utils.seed import setup_seed
10
+ from video_to_video.utils.logger import get_logger
11
+ from video_super_resolution.color_fix import adain_color_fix
12
 
13
+ from inference_utils import *
 
14
 
15
+ logger = get_logger()
 
 
 
 
16
 
 
 
17
 
18
class STAR_sr:
    """Run the ``VideoToVideo_sr`` model on one video and save the result.

    The constructor builds the model from ``model_path`` and records the
    sampling settings; ``enhance_a_video`` processes a single input file and
    writes the output to ``result_dir`` / ``file_name``.
    """

    def __init__(self,
                 result_dir='./results/',
                 file_name='000_video.mp4',
                 model_path='./pretrained_weight',
                 solver_mode='fast',
                 steps=15,
                 guide_scale=7.5,
                 upscale=4,
                 max_chunk_len=32,
                 variant_info=None,
                 chunk_size=3,
                 ):
        """Create the output directory, build the model and store settings.

        Args:
            result_dir: Directory the output video is written to; created if
                it does not exist.
            file_name: File name of the output video inside ``result_dir``.
            model_path: Checkpoint path handed to the model configuration.
            solver_mode: Solver mode forwarded to ``model.test``; when it is
                ``'fast'`` the step count is forced to 15.
            steps: Number of sampling steps (ignored in ``'fast'`` mode).
            guide_scale: Guidance scale forwarded to ``model.test``.
            upscale: Integer factor applied to the input height and width.
            max_chunk_len: Forwarded to ``model.test`` as ``max_chunk_len``.
            variant_info: Stored on the instance; not read by this class.
            chunk_size: Stored in the model configuration as ``chunk_size``.
        """
        self.model_path = model_path
        logger.info('checkpoint_path: {}'.format(self.model_path))

        self.result_dir = result_dir
        self.file_name = file_name
        os.makedirs(self.result_dir, exist_ok=True)

        # Configuration object consumed by the model constructor.
        cfg = EasyDict(__name__='model_cfg')
        cfg.model_path = self.model_path
        cfg.chunk_size = chunk_size
        self.model = VideoToVideo_sr(cfg)

        self.solver_mode = solver_mode
        # 'fast' mode always uses 15 steps, whatever the caller asked for.
        self.steps = steps if solver_mode != 'fast' else 15
        self.guide_scale = guide_scale
        self.upscale = upscale
        self.max_chunk_len = max_chunk_len
        self.variant_info = variant_info

    def enhance_a_video(self, video_path, prompt):
        """Super-resolve ``video_path`` guided by ``prompt`` and save it.

        Args:
            video_path: Path of the input video passed to ``load_video``.
            prompt: Text description; the model's ``positive_prompt`` is
                appended to it to form the caption.

        Returns:
            ``os.path.join(self.result_dir, self.file_name)``, the location
            the output video was saved to.
        """
        logger.info('input video path: {}'.format(video_path))
        logger.info('text: {}'.format(prompt))
        caption = prompt + self.model.positive_prompt

        frames, fps = load_video(video_path)
        logger.info('input frames length: {}'.format(len(frames)))
        logger.info('input fps: {}'.format(fps))

        video_data = preprocess(frames)
        # The last two of the four dimensions are used as height and width.
        _, _, src_h, src_w = video_data.shape
        logger.info('input resolution: {}'.format((src_h, src_w)))
        target_res = (src_h * self.upscale, src_w * self.upscale)
        logger.info('target resolution: {}'.format(target_res))

        model_inputs = {
            'video_data': video_data,
            'y': caption,
            'target_res': target_res,
        }

        noise_levels = 900
        # Fixed seed, set right before sampling.
        setup_seed(666)

        with torch.no_grad():
            batch = collate_fn(model_inputs, 'cuda:0')
            sr_output = self.model.test(
                batch,
                noise_levels,
                steps=self.steps,
                solver_mode=self.solver_mode,
                guide_scale=self.guide_scale,
                max_chunk_len=self.max_chunk_len,
            )

        sr_frames = tensor2vid(sr_output)

        # Colour fix, using the preprocessed input as the reference.
        sr_frames = adain_color_fix(sr_frames, video_data)

        save_video(sr_frames, self.result_dir, self.file_name, fps=fps)
        return os.path.join(self.result_dir, self.file_name)
88
 
89
+
90
def parse_args(argv=None):
    """Parse the command-line options of the super-resolution script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads the process arguments, which is the
            behaviour existing callers rely on.

    Returns:
        An ``argparse.Namespace`` with the attributes ``input_path``,
        ``save_dir``, ``file_name``, ``model_path``, ``prompt``, ``upscale``,
        ``max_chunk_len``, ``variant_info``, ``cfg``, ``solver_mode`` and
        ``steps``.

    Raises:
        SystemExit: Raised by argparse when ``--input_path`` is missing or an
            option value is invalid (including an unknown ``--solver_mode``).
    """
    parser = ArgumentParser()

    parser.add_argument("--input_path", required=True, type=str, help="input video path")
    parser.add_argument("--save_dir", type=str, default='results', help="save directory")
    parser.add_argument("--file_name", type=str, help="file name")
    parser.add_argument("--model_path", type=str, default='./pretrained_weight/I2VGen-XL-based/heavy_deg.pt', help="model path")
    parser.add_argument("--prompt", type=str, default='a good video', help="prompt")
    parser.add_argument("--upscale", type=int, default=4, help='up-scale')
    parser.add_argument("--max_chunk_len", type=int, default=32, help='max_chunk_len')
    parser.add_argument("--variant_info", type=str, default=None, help='information of inference strategy')

    parser.add_argument("--cfg", type=float, default=7.5)
    # Reject unknown modes at parse time, with a usage message, instead of
    # failing later in the program.
    parser.add_argument("--solver_mode", type=str, default='fast', choices=('fast', 'normal'), help='fast | normal')
    parser.add_argument("--steps", type=int, default=15)

    return parser.parse_args(argv)
107
+
108
def main():
    """Command-line entry point: parse options and enhance one video.

    Reads the options with ``parse_args``, builds a ``STAR_sr`` instance from
    them and calls ``enhance_a_video`` on the requested input.

    Raises:
        ValueError: If ``--solver_mode`` is neither ``'fast'`` nor
            ``'normal'``.
    """
    args = parse_args()

    # Explicit check instead of ``assert`` so it still runs under ``python -O``.
    if args.solver_mode not in ('fast', 'normal'):
        raise ValueError(
            "solver_mode must be 'fast' or 'normal', got {!r}".format(args.solver_mode)
        )

    # ``--file_name`` is optional and defaults to None, which would make the
    # final ``os.path.join(result_dir, file_name)`` fail only after the whole
    # video has been processed. Fall back to naming the output after the input.
    file_name = args.file_name
    if not file_name:
        stem = os.path.splitext(os.path.basename(args.input_path))[0]
        file_name = stem + '.mp4'

    star_sr = STAR_sr(
        result_dir=args.save_dir,
        file_name=file_name,
        model_path=args.model_path,
        solver_mode=args.solver_mode,
        steps=args.steps,
        guide_scale=args.cfg,
        upscale=args.upscale,
        max_chunk_len=args.max_chunk_len,
        # Forward the parsed option (default None) rather than discarding it.
        variant_info=args.variant_info,
    )

    star_sr.enhance_a_video(args.input_path, args.prompt)


if __name__ == '__main__':
    main()