Chandans01 committed
Commit 0ea577e · verified · 1 Parent(s): 5d04448

Delete preprocess.py

Files changed (1):
  1. preprocess.py +0 -260
preprocess.py DELETED
@@ -1,260 +0,0 @@
- import torch
- import torchvision
- import json
- import os
- import random
- import numpy as np
- import argparse
- import decord
-
- from einops import rearrange
- from torchvision import transforms
- from tqdm import tqdm
- from PIL import Image
- from decord import VideoReader, cpu
- from transformers import Blip2Processor, Blip2ForConditionalGeneration
-
- decord.bridge.set_bridge('torch')
-
- class PreProcessVideos:
-     def __init__(
-         self,
-         config_name,
-         config_save_name,
-         video_directory,
-         random_start_frame,
-         clip_frame_data,
-         max_frames,
-         beam_amount,
-         prompt_amount,
-         min_prompt_length,
-         max_prompt_length,
-         save_dir
-     ):
-
-         # Parameters for parsing videos
-         self.prompt_amount = prompt_amount
-         self.video_directory = video_directory
-         self.random_start_frame = random_start_frame
-         self.clip_frame_data = clip_frame_data
-         self.max_frames = max_frames
-         self.vid_types = (".mp4", ".avi", ".mov", ".webm", ".flv", ".mjpeg")
-
-         # Parameters for BLIP2
-         self.processor = None
-         self.blip_model = None
-         self.beam_amount = beam_amount
-         self.min_length = min_prompt_length
-         self.max_length = max_prompt_length
-
-         # Helper parameters
-         self.device = "cuda" if torch.cuda.is_available() else "cpu"
-         self.save_dir = save_dir
-
-         # Config parameters
-         self.config_name = config_name
-         self.config_save_name = config_save_name
-
-     # Base dict to hold all the data.
-     # {base_config}
-     def build_base_config(self):
-         return {
-             "name": self.config_name,
-             "data": []
-         }
-
-     # Video dict for individual videos.
-     # {base_config: data -> [{video_path, num_frames, data}]}
-     def build_video_config(self, video_path: str, num_frames: int):
-         return {
-             "video_path": video_path,
-             "num_frames": num_frames,
-             "data": []
-         }
-
-     # Dict for video frames and prompts / captions.
-     # Gets the frame index, then gets a caption for that frame and stores it.
-     # {base_config: data -> [{name, num_frames, data: {frame_index, prompt}}]}
-     def build_video_data(self, frame_index: int, prompt: str):
-         return {
-             "frame_index": frame_index,
-             "prompt": prompt
-         }
-
-     # Load BLIP2 for processing
-     def load_blip(self):
-         print("Loading BLIP2")
-
-         processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
-         model = Blip2ForConditionalGeneration.from_pretrained(
-             "Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16
-         )
-         model.to(self.device)
-
-         self.processor = processor
-         self.blip_model = model
-
-     # Process the frames to get the length and image.
-     # The limit parameter ensures we don't get near the max frame length.
-     def video_processor(
-         self,
-         video_reader: VideoReader,
-         num_frames: int,
-         random_start_frame=True,
-         frame_num=0
-     ):
-
-         frame_number = (
-             random.randrange(0, int(num_frames)) if random_start_frame else frame_num
-         )
-         frame = video_reader[frame_number].permute(2, 0, 1)
-         image = transforms.ToPILImage()(frame).convert("RGB")
-         return frame_number, image
-
-     def get_frame_range(self, deterministic):
-         return range(self.prompt_amount) if self.random_start_frame else deterministic
-
-     def process_blip(self, image: Image.Image):
-         inputs = self.processor(images=image, return_tensors="pt").to(self.device, torch.float16)
-         generated_ids = self.blip_model.generate(
-             **inputs,
-             num_beams=self.beam_amount,
-             min_length=self.min_length,
-             max_length=self.max_length
-         )
-         generated_text = self.processor.batch_decode(
-             generated_ids,
-             skip_special_tokens=True)[0].strip()
-
-         return generated_text
-
-     def get_out_paths(self, prompt, frame_number):
-         out_name = f"{prompt}_{str(frame_number)}"
-         save_path = f"{self.save_dir}/{self.config_save_name}"
-         save_filepath = f"{save_path}/{out_name}.mp4"
-
-         return out_name, save_path, save_filepath
-
-     def save_train_config(self, config: dict):
-         os.makedirs(self.save_dir, exist_ok=True)
-
-         save_json = json.dumps(config, indent=4)
-         save_dir = f"{self.save_dir}/{self.config_save_name}"
-
-         with open(f"{save_dir}.json", 'w') as f:
-             f.write(save_json)
-
-     def save_video(self, save_path, save_filepath, frames):
-         os.makedirs(save_path, exist_ok=True)
-         torchvision.io.write_video(save_filepath, frames, fps=30)
-
-     # Main loop for processing all videos.
-     def process_videos(self):
-         self.load_blip()
-         config = self.build_base_config()
-
-         if not os.path.exists(self.video_directory):
-             raise ValueError(f"{self.video_directory} does not exist.")
-
-         for root, _, files in tqdm(
-             os.walk(self.video_directory),
-             desc=f"Processing videos in {self.video_directory}"
-         ):
-             for video in files:
-                 if video.endswith(self.vid_types):
-                     video_path = f"{root}/{video}"
-                     video_reader = None
-                     deterministic_range = None
-                     video_len = 0
-                     try:
-                         video_reader = VideoReader(video_path, ctx=cpu(0))
-                         video_len = len(video_reader)
-                         frame_step = max(1, video_len // self.prompt_amount)
-                         deterministic_range = range(1, abs(video_len - 1), frame_step)
-                     except Exception:
-                         print(f"Error loading {video_path}. Video may be unsupported or corrupt.")
-                         continue
-
-                     # Another try/except block because decord isn't perfect.
-                     try:
-                         num_frames = int(len(video_reader))
-                         video_config = self.build_video_config(video_path, num_frames)
-
-                         # Secondary loop that processes a specified amount of prompts, selects a frame, then appends it.
-                         for i in tqdm(
-                             self.get_frame_range(deterministic_range),
-                             desc=f"Processing {os.path.basename(video_path)}"
-                         ):
-                             frame_number, image = self.video_processor(
-                                 video_reader,
-                                 num_frames,
-                                 self.random_start_frame,
-                                 frame_num=i
-                             )
-
-                             prompt = self.process_blip(image)
-                             video_data = self.build_video_data(frame_number, prompt)
-
-                             if self.clip_frame_data:
-
-                                 # Clamp the frame number between 1 and the last frame of the video.
-                                 max_range = abs(len(video_reader) - 1)
-                                 frame_number = i
-                                 frame_number = sorted((1, frame_number, max_range))[1]
-
-                                 frame_range = range(frame_number, max_range)
-                                 frame_range_nums = list(frame_range)
-
-                                 frames = video_reader.get_batch(frame_range_nums[:self.max_frames])
-
-                                 out_name, save_path, save_filepath = self.get_out_paths(prompt, frame_number)
-
-                                 self.save_video(save_path, save_filepath, frames)
-
-                                 video_data['clip_path'] = save_filepath
-                                 video_config["data"].append(video_data)
-
-                             else:
-                                 video_config["data"].append(video_data)
-
-                         config['data'].append(video_config)
-
-                     except Exception as e:
-                         print(e)
-                         continue
-                 else:
-                     continue
-
-         print(f"Done. Saving train config to {self.save_dir}.")
-         self.save_train_config(config)
-
- if __name__ == "__main__":
-     parser = argparse.ArgumentParser()
-
-     parser.add_argument('--config_name', help="The name of the configuration.", type=str, default='My Config')
-     parser.add_argument('--config_save_name', help="The name of the config file that's saved.", type=str, default='my_config')
-     parser.add_argument('--video_directory', help="The directory where your videos are located.", type=str, default='./videos')
-     parser.add_argument(
-         '--random_start_frame',
-         help="Use a random start frame when processing videos. Good for long videos where frames have different scenes and meanings.",
-         action='store_true',
-         default=False
-     )
-     parser.add_argument(
-         '--clip_frame_data',
-         help="Save the frames as video clips to HDD/SSD. Video clips are saved in the same folder as your JSON directory.",
-         action='store_true',
-         default=False
-     )
-     parser.add_argument('--max_frames', help="Maximum frames for clips when --clip_frame_data is enabled.", type=int, default=60)
-     parser.add_argument('--beam_amount', help="Beam amount for BLIP beam search.", type=int, default=7)
-     parser.add_argument('--prompt_amount', help="The amount of prompts per video that is processed.", type=int, default=25)
-     parser.add_argument('--min_prompt_length', help="Minimum tokens required in a generated prompt.", type=int, default=15)
-     parser.add_argument('--max_prompt_length', help="Maximum tokens allowed in a generated prompt.", type=int, default=30)
-     parser.add_argument('--save_dir', help="The directory to save the config to.", type=str, default=f"{os.getcwd()}/train_data")
-
-     args = parser.parse_args()
-
-
-     processor = PreProcessVideos(**vars(args))
-     processor.process_videos()
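
For reference, the deleted script wrote its results as a single JSON file at {save_dir}/{config_save_name}.json. Below is a minimal sketch of that output shape, reconstructed from build_base_config, build_video_config, and build_video_data above; the video path, frame index, and prompt string are hypothetical placeholders, not values from any real run.

    # Shape of the config JSON produced by the deleted preprocess.py.
    # All concrete values below are hypothetical examples.
    example_config = {
        "name": "My Config",                           # --config_name
        "data": [
            {
                "video_path": "./videos/example.mp4",  # hypothetical input file
                "num_frames": 300,
                "data": [
                    {
                        "frame_index": 42,
                        "prompt": "a person riding a bicycle down a street",
                        # Added only when --clip_frame_data is set; points at the
                        # clip saved under {save_dir}/{config_save_name}/:
                        # "clip_path": "./train_data/my_config/<prompt>_42.mp4",
                    },
                ],
            },
        ],
    }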