tsqn committed
Commit c628a1c · verified · 1 Parent(s): 3667bd9

Update app.py

Files changed (1)
  1. app.py +148 -147
app.py CHANGED
@@ -39,7 +39,7 @@ import utils
 #from huggingface_hub import hf_hub_download, snapshot_download
 import gc

-device = "cuda" if torch.cuda.is_available() else "cpu"
+#device = "cuda" if torch.cuda.is_available() else "cpu"

 #hf_hub_download(repo_id="ai-forever/Real-ESRGAN", filename="RealESRGAN_x4.pth", local_dir="model_real_esran")
 #snapshot_download(repo_id="AlexWortega/RIFE", local_dir="model_rife")
@@ -65,14 +65,14 @@ pipe.enable_model_cpu_offload()
 pipe.vae.enable_tiling()
 pipe.vae.enable_slicing()

-i2v_transformer = CogVideoXTransformer3DModel.from_pretrained(
-    "THUDM/CogVideoX-5B-I2V", subfolder="transformer", torch_dtype=torch.bfloat16
-)
-i2v_text_encoder = T5EncoderModel.from_pretrained("THUDM/CogVideoX-5B-I2V", subfolder="text_encoder", torch_dtype=torch.bfloat16)
-i2v_vae = AutoencoderKLCogVideoX.from_pretrained("THUDM/CogVideoX-5B-I2V", subfolder="vae", torch_dtype=torch.bfloat16)
+# i2v_transformer = CogVideoXTransformer3DModel.from_pretrained(
+#     "THUDM/CogVideoX-5B-I2V", subfolder="transformer", torch_dtype=torch.bfloat16
+# )
+# i2v_text_encoder = T5EncoderModel.from_pretrained("THUDM/CogVideoX-5B-I2V", subfolder="text_encoder", torch_dtype=torch.bfloat16)
+# i2v_vae = AutoencoderKLCogVideoX.from_pretrained("THUDM/CogVideoX-5B-I2V", subfolder="vae", torch_dtype=torch.bfloat16)

-quantize_(i2v_transformer, quantization())
-quantize_(i2v_text_encoder, quantization())
+# quantize_(i2v_transformer, quantization())
+# quantize_(i2v_text_encoder, quantization())
 # quantize_(i2v_vae, quantization())

 # pipe.transformer.to(memory_format=torch.channels_last)
@@ -100,78 +100,78 @@ Video descriptions must have the same num of words as examples below. Extra word
 """


-def resize_if_unfit(input_video, progress=gr.Progress(track_tqdm=True)):
-    width, height = get_video_dimensions(input_video)
+# def resize_if_unfit(input_video, progress=gr.Progress(track_tqdm=True)):
+#     width, height = get_video_dimensions(input_video)

-    if width == 720 and height == 480:
-        processed_video = input_video
-    else:
-        processed_video = center_crop_resize(input_video)
-    return processed_video
+#     if width == 720 and height == 480:
+#         processed_video = input_video
+#     else:
+#         processed_video = center_crop_resize(input_video)
+#     return processed_video


-def get_video_dimensions(input_video_path):
-    reader = imageio_ffmpeg.read_frames(input_video_path)
-    metadata = next(reader)
-    return metadata["size"]
+# def get_video_dimensions(input_video_path):
+#     reader = imageio_ffmpeg.read_frames(input_video_path)
+#     metadata = next(reader)
+#     return metadata["size"]


-def center_crop_resize(input_video_path, target_width=720, target_height=480):
-    cap = cv2.VideoCapture(input_video_path)
+# def center_crop_resize(input_video_path, target_width=720, target_height=480):
+#     cap = cv2.VideoCapture(input_video_path)

-    orig_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
-    orig_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
-    orig_fps = cap.get(cv2.CAP_PROP_FPS)
-    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+#     orig_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+#     orig_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+#     orig_fps = cap.get(cv2.CAP_PROP_FPS)
+#     total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

-    width_factor = target_width / orig_width
-    height_factor = target_height / orig_height
-    resize_factor = max(width_factor, height_factor)
+#     width_factor = target_width / orig_width
+#     height_factor = target_height / orig_height
+#     resize_factor = max(width_factor, height_factor)

-    inter_width = int(orig_width * resize_factor)
-    inter_height = int(orig_height * resize_factor)
+#     inter_width = int(orig_width * resize_factor)
+#     inter_height = int(orig_height * resize_factor)

-    target_fps = 8
-    ideal_skip = max(0, math.ceil(orig_fps / target_fps) - 1)
-    skip = min(5, ideal_skip) # Cap at 5
+#     target_fps = 8
+#     ideal_skip = max(0, math.ceil(orig_fps / target_fps) - 1)
+#     skip = min(5, ideal_skip) # Cap at 5

-    while (total_frames / (skip + 1)) < 49 and skip > 0:
-        skip -= 1
+#     while (total_frames / (skip + 1)) < 49 and skip > 0:
+#         skip -= 1

-    processed_frames = []
-    frame_count = 0
-    total_read = 0
+#     processed_frames = []
+#     frame_count = 0
+#     total_read = 0

-    while frame_count < 49 and total_read < total_frames:
-        ret, frame = cap.read()
-        if not ret:
-            break
+#     while frame_count < 49 and total_read < total_frames:
+#         ret, frame = cap.read()
+#         if not ret:
+#             break

-        if total_read % (skip + 1) == 0:
-            resized = cv2.resize(frame, (inter_width, inter_height), interpolation=cv2.INTER_AREA)
+#         if total_read % (skip + 1) == 0:
+#             resized = cv2.resize(frame, (inter_width, inter_height), interpolation=cv2.INTER_AREA)

-            start_x = (inter_width - target_width) // 2
-            start_y = (inter_height - target_height) // 2
-            cropped = resized[start_y : start_y + target_height, start_x : start_x + target_width]
+#             start_x = (inter_width - target_width) // 2
+#             start_y = (inter_height - target_height) // 2
+#             cropped = resized[start_y : start_y + target_height, start_x : start_x + target_width]

-            processed_frames.append(cropped)
-            frame_count += 1
+#             processed_frames.append(cropped)
+#             frame_count += 1

-        total_read += 1
+#         total_read += 1

-    cap.release()
+#     cap.release()

-    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as temp_file:
-        temp_video_path = temp_file.name
-        fourcc = cv2.VideoWriter_fourcc(*"mp4v")
-        out = cv2.VideoWriter(temp_video_path, fourcc, target_fps, (target_width, target_height))
+#     with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as temp_file:
+#         temp_video_path = temp_file.name
+#         fourcc = cv2.VideoWriter_fourcc(*"mp4v")
+#         out = cv2.VideoWriter(temp_video_path, fourcc, target_fps, (target_width, target_height))

-        for frame in processed_frames:
-            out.write(frame)
+#         for frame in processed_frames:
+#             out.write(frame)

-        out.release()
+#         out.release()

-    return temp_video_path
+#     return temp_video_path


 # def convert_prompt(prompt: str, retry_times: int = 3) -> str:
@@ -226,9 +226,9 @@ def center_crop_resize(input_video_path, target_width=720, target_height=480):

 def infer(
     prompt: str,
-    image_input: str,
-    video_input: str,
-    video_strenght: float,
+    # image_input: str,
+    # video_input: str,
+    # video_strenght: float,
     num_inference_steps: int,
     guidance_scale: float,
     seed: int = -1,
@@ -237,76 +237,76 @@ def infer(
     if seed == -1:
         seed = random.randint(0, 2**8 - 1)

-    if video_input is not None:
-        video = load_video(video_input)[:49] # Limit to 49 frames
-        pipe_video = CogVideoXVideoToVideoPipeline.from_pretrained(
-            "THUDM/CogVideoX-5B",
-            transformer=transformer,
-            vae=vae,
-            scheduler=pipe.scheduler,
-            tokenizer=pipe.tokenizer,
-            text_encoder=text_encoder,
-            torch_dtype=torch.bfloat16,
-        ).to(device)
+    # if video_input is not None:
+    #     video = load_video(video_input)[:49] # Limit to 49 frames
+    #     pipe_video = CogVideoXVideoToVideoPipeline.from_pretrained(
+    #         "THUDM/CogVideoX-5B",
+    #         transformer=transformer,
+    #         vae=vae,
+    #         scheduler=pipe.scheduler,
+    #         tokenizer=pipe.tokenizer,
+    #         text_encoder=text_encoder,
+    #         torch_dtype=torch.bfloat16,
+    #     ).to(device)

-        # pipe_video.enable_model_cpu_offload()
-        pipe_video.vae.enable_tiling()
-        pipe_video.vae.enable_slicing()
-        video_pt = pipe_video(
-            video=video,
-            prompt=prompt,
-            num_inference_steps=num_inference_steps,
-            num_videos_per_prompt=1,
-            strength=video_strenght,
-            use_dynamic_cfg=True,
-            output_type="pt",
-            guidance_scale=guidance_scale,
-            generator=torch.Generator(device="cpu").manual_seed(seed),
-        ).frames
-        pipe_video.to("cpu")
-        del pipe_video
-        gc.collect()
-        torch.cuda.empty_cache()
-    elif image_input is not None:
-        pipe_image = CogVideoXImageToVideoPipeline.from_pretrained(
-            "THUDM/CogVideoX-5B-I2V",
-            transformer=i2v_transformer,
-            vae=i2v_vae,
-            scheduler=pipe.scheduler,
-            tokenizer=pipe.tokenizer,
-            text_encoder=i2v_text_encoder,
-            torch_dtype=torch.bfloat16,
-        ).to(device)
-        image_input = Image.fromarray(image_input).resize(size=(720, 480)) # Convert to PIL
-        image = load_image(image_input)
-        video_pt = pipe_image(
-            image=image,
-            prompt=prompt,
-            num_inference_steps=num_inference_steps,
-            num_videos_per_prompt=1,
-            use_dynamic_cfg=True,
-            output_type="pt",
-            guidance_scale=guidance_scale,
-            generator=torch.Generator(device="cpu").manual_seed(seed),
-        ).frames
-        pipe_image.to("cpu")
-        del pipe_image
-        gc.collect()
-        torch.cuda.empty_cache()
-    else:
-        pipe.to(device)
-        video_pt = pipe(
-            prompt=prompt,
-            num_videos_per_prompt=1,
-            num_inference_steps=num_inference_steps,
-            num_frames=24,
-            use_dynamic_cfg=True,
-            output_type="pt",
-            guidance_scale=guidance_scale,
-            generator=torch.Generator(device="cpu").manual_seed(seed),
-        ).frames
-        pipe.to("cpu")
-        gc.collect()
+    #     # pipe_video.enable_model_cpu_offload()
+    #     pipe_video.vae.enable_tiling()
+    #     pipe_video.vae.enable_slicing()
+    #     video_pt = pipe_video(
+    #         video=video,
+    #         prompt=prompt,
+    #         num_inference_steps=num_inference_steps,
+    #         num_videos_per_prompt=1,
+    #         strength=video_strenght,
+    #         use_dynamic_cfg=True,
+    #         output_type="pt",
+    #         guidance_scale=guidance_scale,
+    #         generator=torch.Generator(device="cpu").manual_seed(seed),
+    #     ).frames
+    #     pipe_video.to("cpu")
+    #     del pipe_video
+    #     gc.collect()
+    #     torch.cuda.empty_cache()
+    # elif image_input is not None:
+    #     pipe_image = CogVideoXImageToVideoPipeline.from_pretrained(
+    #         "THUDM/CogVideoX-5B-I2V",
+    #         transformer=i2v_transformer,
+    #         vae=i2v_vae,
+    #         scheduler=pipe.scheduler,
+    #         tokenizer=pipe.tokenizer,
+    #         text_encoder=i2v_text_encoder,
+    #         torch_dtype=torch.bfloat16,
+    #     ).to(device)
+    #     image_input = Image.fromarray(image_input).resize(size=(720, 480)) # Convert to PIL
+    #     image = load_image(image_input)
+    #     video_pt = pipe_image(
+    #         image=image,
+    #         prompt=prompt,
+    #         num_inference_steps=num_inference_steps,
+    #         num_videos_per_prompt=1,
+    #         use_dynamic_cfg=True,
+    #         output_type="pt",
+    #         guidance_scale=guidance_scale,
+    #         generator=torch.Generator(device="cpu").manual_seed(seed),
+    #     ).frames
+    #     pipe_image.to("cpu")
+    #     del pipe_image
+    #     gc.collect()
+    #     torch.cuda.empty_cache()
+    # else:
+    pipe.to("cpu")
+    video_pt = pipe(
+        prompt=prompt,
+        num_videos_per_prompt=1,
+        num_inference_steps=num_inference_steps,
+        num_frames=16,
+        use_dynamic_cfg=True,
+        output_type="pt",
+        guidance_scale=guidance_scale,
+        generator=torch.Generator(device="cpu").manual_seed(seed),
+    ).frames
+    pipe.to("cpu")
+    gc.collect()
     return (video_pt, seed)


@@ -362,13 +362,13 @@ with gr.Blocks() as demo:
     """)
     with gr.Row():
        with gr.Column():
-            with gr.Accordion("I2V: Image Input (cannot be used simultaneously with video input)", open=False):
-                image_input = gr.Image(label="Input Image (will be cropped to 720 * 480)")
-                examples_component_images = gr.Examples(examples_images, inputs=[image_input], cache_examples=False)
-            with gr.Accordion("V2V: Video Input (cannot be used simultaneously with image input)", open=False):
-                video_input = gr.Video(label="Input Video (will be cropped to 49 frames, 6 seconds at 8fps)")
-                strength = gr.Slider(0.1, 1.0, value=0.8, step=0.01, label="Strength")
-                examples_component_videos = gr.Examples(examples_videos, inputs=[video_input], cache_examples=False)
+            # with gr.Accordion("I2V: Image Input (cannot be used simultaneously with video input)", open=False):
+            #     image_input = gr.Image(label="Input Image (will be cropped to 720 * 480)")
+            #     examples_component_images = gr.Examples(examples_images, inputs=[image_input], cache_examples=False)
+            # with gr.Accordion("V2V: Video Input (cannot be used simultaneously with image input)", open=False):
+            #     video_input = gr.Video(label="Input Video (will be cropped to 49 frames, 6 seconds at 8fps)")
+            #     strength = gr.Slider(0.1, 1.0, value=0.8, step=0.01, label="Strength")
+            #     examples_component_videos = gr.Examples(examples_videos, inputs=[video_input], cache_examples=False)
             prompt = gr.Textbox(label="Prompt (Less than 200 Words)", placeholder="Enter your prompt here", lines=5)

             # with gr.Row():
@@ -465,9 +465,9 @@ with gr.Blocks() as demo:
     @spaces.GPU(duration=120)
     def generate(
         prompt,
-        image_input,
-        video_input,
-        video_strength,
+        # image_input,
+        # video_input,
+        # video_strength,
         seed_value,
         # scale_status,
         # rife_status,
@@ -475,10 +475,10 @@ with gr.Blocks() as demo:
     ):
         latents, seed = infer(
             prompt,
-            image_input,
-            video_input,
-            video_strength,
-            num_inference_steps=20, # Changed from 50
+            # image_input,
+            # video_input,
+            # video_strength,
+            num_inference_steps=50, # NOT Changed
             guidance_scale=7.0, # NOT Changed
             seed=seed_value,
             progress=progress,
@@ -511,13 +511,14 @@ with gr.Blocks() as demo:

     generate_button.click(
         generate,
-        inputs=[prompt, image_input, video_input, strength, seed_param],
+        inputs=[prompt, seed_param],
+        # inputs=[prompt, image_input, video_input, strength, seed_param],
         # inputs=[prompt, image_input, video_input, strength, seed_param, enable_scale, enable_rife],
         outputs=[video_output, download_video_button, download_gif_button, seed_text],
     )

     # enhance_button.click(enhance_prompt_func, inputs=[prompt], outputs=[prompt])
-    video_input.upload(resize_if_unfit, inputs=[video_input], outputs=[video_input])
+    # video_input.upload(resize_if_unfit, inputs=[video_input], outputs=[video_input])

 if __name__ == "__main__":
     utils.install_packages()
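
Taken together, the change disables the image-to-video and video-to-video branches and leaves only the text-to-video path, pinned to the CPU with 16 frames, 50 inference steps, and guidance 7.0. Below is a minimal standalone sketch of that remaining path; it is not code from the commit, and it assumes the diffusers CogVideoXPipeline API with a hypothetical prompt and seed.

import torch
from diffusers import CogVideoXPipeline

# Load the same base model the app uses and keep it on the CPU, mirroring pipe.to("cpu") in the diff.
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5B", torch_dtype=torch.bfloat16)
pipe.to("cpu")
pipe.vae.enable_tiling()
pipe.vae.enable_slicing()

# Generation settings match the values this commit hard-codes: 16 frames, 50 steps, CFG 7.0.
frames = pipe(
    prompt="A panda strumming a guitar in a bamboo grove",  # hypothetical example prompt
    num_videos_per_prompt=1,
    num_inference_steps=50,
    num_frames=16,
    use_dynamic_cfg=True,
    output_type="pt",
    guidance_scale=7.0,
    generator=torch.Generator(device="cpu").manual_seed(42),
).frames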