tsqn committed
Commit 2a75bce · verified · 1 Parent(s): 134567e

Delete inference.py

Files changed (1)
  1. inference.py +0 -369
inference.py DELETED
@@ -1,369 +0,0 @@
- import torch
- from xora.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
- from xora.models.transformers.transformer3d import Transformer3DModel
- from xora.models.transformers.symmetric_patchifier import SymmetricPatchifier
- from xora.schedulers.rf import RectifiedFlowScheduler
- from xora.pipelines.pipeline_xora_video import XoraVideoPipeline
- from pathlib import Path
- from transformers import T5EncoderModel, T5Tokenizer
- import safetensors.torch
- import json
- import argparse
- from xora.utils.conditioning_method import ConditioningMethod
- import os
- import numpy as np
- import cv2
- from PIL import Image
- import random
-
- RECOMMENDED_RESOLUTIONS = [
-     (704, 1216, 41),
-     (704, 1088, 49),
-     (640, 1056, 57),
-     (608, 992, 65),
-     (608, 896, 73),
-     (544, 896, 81),
-     (544, 832, 89),
-     (512, 800, 97),
-     (512, 768, 97),
-     (480, 800, 105),
-     (480, 736, 113),
-     (480, 704, 121),
-     (448, 704, 129),
-     (448, 672, 137),
-     (416, 640, 153),
-     (384, 672, 161),
-     (384, 640, 169),
-     (384, 608, 177),
-     (384, 576, 185),
-     (352, 608, 193),
-     (352, 576, 201),
-     (352, 544, 209),
-     (352, 512, 225),
-     (352, 512, 233),
-     (320, 544, 241),
-     (320, 512, 249),
-     (320, 512, 257),
- ]
-
-
- def load_vae(vae_dir):
-     vae_ckpt_path = vae_dir / "vae_diffusion_pytorch_model.safetensors"
-     vae_config_path = vae_dir / "config.json"
-     with open(vae_config_path, "r") as f:
-         vae_config = json.load(f)
-     vae = CausalVideoAutoencoder.from_config(vae_config)
-     vae_state_dict = safetensors.torch.load_file(vae_ckpt_path)
-     vae.load_state_dict(vae_state_dict)
-     if torch.cuda.is_available():
-         vae = vae.cuda()
-     return vae.to(torch.bfloat16)
-
-
- def load_unet(unet_dir):
-     unet_ckpt_path = unet_dir / "unet_diffusion_pytorch_model.safetensors"
-     unet_config_path = unet_dir / "config.json"
-     transformer_config = Transformer3DModel.load_config(unet_config_path)
-     transformer = Transformer3DModel.from_config(transformer_config)
-     unet_state_dict = safetensors.torch.load_file(unet_ckpt_path)
-     transformer.load_state_dict(unet_state_dict, strict=True)
-     if torch.cuda.is_available():
-         transformer = transformer.cuda()
-     return transformer
-
-
- def load_scheduler(scheduler_dir):
-     scheduler_config_path = scheduler_dir / "scheduler_config.json"
-     scheduler_config = RectifiedFlowScheduler.load_config(scheduler_config_path)
-     return RectifiedFlowScheduler.from_config(scheduler_config)
-
-
- def center_crop_and_resize(frame, target_height, target_width):
-     h, w, _ = frame.shape
-     aspect_ratio_target = target_width / target_height
-     aspect_ratio_frame = w / h
-     if aspect_ratio_frame > aspect_ratio_target:
-         new_width = int(h * aspect_ratio_target)
-         x_start = (w - new_width) // 2
-         frame_cropped = frame[:, x_start : x_start + new_width]
-     else:
-         new_height = int(w / aspect_ratio_target)
-         y_start = (h - new_height) // 2
-         frame_cropped = frame[y_start : y_start + new_height, :]
-     frame_resized = cv2.resize(frame_cropped, (target_width, target_height))
-     return frame_resized
-
-
- def load_video_to_tensor_with_resize(video_path, target_height, target_width):
-     cap = cv2.VideoCapture(video_path)
-     frames = []
-     while True:
-         ret, frame = cap.read()
-         if not ret:
-             break
-         frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
-         if target_height is not None:
-             frame_resized = center_crop_and_resize(
-                 frame_rgb, target_height, target_width
-             )
-         else:
-             frame_resized = frame_rgb
-         frames.append(frame_resized)
-     cap.release()
-     video_np = (np.array(frames) / 127.5) - 1.0
-     video_tensor = torch.tensor(video_np).permute(3, 0, 1, 2).float()
-     return video_tensor
-
-
- def load_image_to_tensor_with_resize(image_path, target_height=512, target_width=768):
-     image = Image.open(image_path).convert("RGB")
-     image_np = np.array(image)
-     frame_resized = center_crop_and_resize(image_np, target_height, target_width)
-     frame_tensor = torch.tensor(frame_resized).permute(2, 0, 1).float()
-     frame_tensor = (frame_tensor / 127.5) - 1.0
-     # Create 5D tensor: (batch_size=1, channels=3, num_frames=1, height, width)
-     return frame_tensor.unsqueeze(0).unsqueeze(2)
-
-
- def main():
-     parser = argparse.ArgumentParser(
-         description="Load models from separate directories and run the pipeline."
-     )
-
-     # Directories
-     parser.add_argument(
-         "--ckpt_dir",
-         type=str,
-         required=True,
-         help="Path to the directory containing unet, vae, and scheduler subdirectories",
-     )
-     parser.add_argument(
-         "--input_video_path",
-         type=str,
-         help="Path to the input video file (first frame used)",
-     )
-     parser.add_argument(
-         "--input_image_path", type=str, help="Path to the input image file"
-     )
-     parser.add_argument(
-         "--output_path",
-         type=str,
-         default=None,
-         help="Path to save output video, if None will save in working directory.",
-     )
-     parser.add_argument("--seed", type=int, default="171198")
-
-     # Pipeline parameters
-     parser.add_argument(
-         "--num_inference_steps", type=int, default=40, help="Number of inference steps"
-     )
-     parser.add_argument(
-         "--num_images_per_prompt",
-         type=int,
-         default=1,
-         help="Number of images per prompt",
-     )
-     parser.add_argument(
-         "--guidance_scale",
-         type=float,
-         default=3,
-         help="Guidance scale for the pipeline",
-     )
-     parser.add_argument(
-         "--height",
-         type=int,
-         default=None,
-         help="Height of the output video frames. Optional if an input image provided.",
-     )
-     parser.add_argument(
-         "--width",
-         type=int,
-         default=None,
-         help="Width of the output video frames. If None will infer from input image.",
-     )
-     parser.add_argument(
-         "--num_frames",
-         type=int,
-         default=121,
-         help="Number of frames to generate in the output video",
-     )
-     parser.add_argument(
-         "--frame_rate", type=int, default=25, help="Frame rate for the output video"
-     )
-
-     parser.add_argument(
-         "--bfloat16",
-         action="store_true",
-         help="Denoise in bfloat16",
-     )
-
-     # Prompts
-     parser.add_argument(
-         "--prompt",
-         type=str,
-         help="Text prompt to guide generation",
-     )
-     parser.add_argument(
-         "--negative_prompt",
-         type=str,
-         default="worst quality, inconsistent motion, blurry, jittery, distorted",
-         help="Negative prompt for undesired features",
-     )
-     parser.add_argument(
-         "--custom_resolution",
-         action="store_true",
-         default=False,
-         help="Enable custom resolution (not in recommneded resolutions) if specified (default: False)",
-     )
-
-     args = parser.parse_args()
-
-     if args.input_image_path is None and args.input_video_path is None:
-         assert (
-             args.height is not None and args.width is not None
-         ), "Must enter height and width for text to image generation."
-
-     # Load media (video or image)
-     if args.input_video_path:
-         media_items = load_video_to_tensor_with_resize(
-             args.input_video_path, args.height, args.width
-         ).unsqueeze(0)
-     elif args.input_image_path:
-         media_items = load_image_to_tensor_with_resize(
-             args.input_image_path, args.height, args.width
-         )
-     else:
-         media_items = None
-
-     height = args.height if args.height else media_items.shape[-2]
-     width = args.width if args.width else media_items.shape[-1]
-     assert height % 32 == 0, f"Height ({height}) should be divisible by 32."
-     assert width % 32 == 0, f"Width ({width}) should be divisible by 32."
-     assert (
-         height,
-         width,
-         args.num_frames,
-     ) in RECOMMENDED_RESOLUTIONS or args.custom_resolution, f"The selected resolution + num frames combination is not supported, results would be suboptimal. Supported (h,w,f) are: {RECOMMENDED_RESOLUTIONS}. Use --custom_resolution to enable working with this resolution."
-
-     # Paths for the separate mode directories
-     ckpt_dir = Path(args.ckpt_dir)
-     unet_dir = ckpt_dir / "unet"
-     vae_dir = ckpt_dir / "vae"
-     scheduler_dir = ckpt_dir / "scheduler"
-
-     # Load models
-     vae = load_vae(vae_dir)
-     unet = load_unet(unet_dir)
-     scheduler = load_scheduler(scheduler_dir)
-     patchifier = SymmetricPatchifier(patch_size=1)
-     text_encoder = T5EncoderModel.from_pretrained(
-         "PixArt-alpha/PixArt-XL-2-1024-MS", subfolder="text_encoder"
-     )
-     if torch.cuda.is_available():
-         text_encoder = text_encoder.to("cuda")
-     tokenizer = T5Tokenizer.from_pretrained(
-         "PixArt-alpha/PixArt-XL-2-1024-MS", subfolder="tokenizer"
-     )
-
-     if args.bfloat16 and unet.dtype != torch.bfloat16:
-         unet = unet.to(torch.bfloat16)
-
-     # Use submodels for the pipeline
-     submodel_dict = {
-         "transformer": unet,
-         "patchifier": patchifier,
-         "text_encoder": text_encoder,
-         "tokenizer": tokenizer,
-         "scheduler": scheduler,
-         "vae": vae,
-     }
-
-     pipeline = XoraVideoPipeline(**submodel_dict)
-     if torch.cuda.is_available():
-         pipeline = pipeline.to("cuda")
-
-     # Prepare input for the pipeline
-     sample = {
-         "prompt": args.prompt,
-         "prompt_attention_mask": None,
-         "negative_prompt": args.negative_prompt,
-         "negative_prompt_attention_mask": None,
-         "media_items": media_items,
-     }
-
-     random.seed(args.seed)
-     np.random.seed(args.seed)
-     torch.manual_seed(args.seed)
-     if torch.cuda.is_available():
-         torch.cuda.manual_seed(args.seed)
-
-     generator = torch.Generator(
-         device="cuda" if torch.cuda.is_available() else "cpu"
-     ).manual_seed(args.seed)
-
-     images = pipeline(
-         num_inference_steps=args.num_inference_steps,
-         num_images_per_prompt=args.num_images_per_prompt,
-         guidance_scale=args.guidance_scale,
-         generator=generator,
-         output_type="pt",
-         callback_on_step_end=None,
-         height=height,
-         width=width,
-         num_frames=args.num_frames,
-         frame_rate=args.frame_rate,
-         **sample,
-         is_video=True,
-         vae_per_channel_normalize=True,
-         conditioning_method=(
-             ConditioningMethod.FIRST_FRAME
-             if media_items is not None
-             else ConditioningMethod.UNCONDITIONAL
-         ),
-         mixed_precision=not args.bfloat16,
-     ).images
-
-     # Save output video
-     def get_unique_filename(base, ext, dir=".", index_range=1000):
-         for i in range(index_range):
-             filename = os.path.join(dir, f"{base}_{i}{ext}")
-             if not os.path.exists(filename):
-                 return filename
-         raise FileExistsError(
-             f"Could not find a unique filename after {index_range} attempts."
-         )
-
-     for i in range(images.shape[0]):
-         # Gathering from B, C, F, H, W to C, F, H, W and then permuting to F, H, W, C
-         video_np = images[i].permute(1, 2, 3, 0).cpu().float().numpy()
-         # Unnormalizing images to [0, 255] range
-         video_np = (video_np * 255).astype(np.uint8)
-         fps = args.frame_rate
-         height, width = video_np.shape[1:3]
-         if video_np.shape[0] == 1:
-             output_filename = (
-                 args.output_path
-                 if args.output_path is not None
-                 else get_unique_filename(f"image_output_{i}", ".png", ".")
-             )
-             cv2.imwrite(
-                 output_filename, video_np[0][..., ::-1]
-             )  # Save single frame as image
-         else:
-             output_filename = (
-                 args.output_path
-                 if args.output_path is not None
-                 else get_unique_filename(f"video_output_{i}", ".mp4", ".")
-             )
-
-             out = cv2.VideoWriter(
-                 output_filename, cv2.VideoWriter_fourcc(*"mp4v"), fps, (width, height)
-             )
-
-             for frame in video_np[..., ::-1]:
-                 out.write(frame)
-             out.release()
-
-
- if __name__ == "__main__":
-     main()
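
For reference, the removed script exposed a command-line interface. The snippet below is a minimal sketch of how it was typically invoked, reconstructed only from the argparse definitions and RECOMMENDED_RESOLUTIONS list in the deleted file; the checkpoint path and prompt are placeholders.

# Hypothetical invocation of the removed inference.py, reconstructed from its
# argparse definitions above; paths and prompt are placeholders.
import subprocess

subprocess.run(
    [
        "python", "inference.py",
        "--ckpt_dir", "/path/to/checkpoints",  # expects unet/, vae/, scheduler/ subdirectories
        "--prompt", "a red fox running through fresh snow",
        "--height", "512",                     # height and width must be divisible by 32
        "--width", "768",
        "--num_frames", "97",                  # (512, 768, 97) appears in RECOMMENDED_RESOLUTIONS
        "--num_inference_steps", "40",
        "--seed", "171198",
    ],
    check=True,
)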