alexnasa committed
Commit 69d77bb · verified · 1 Parent(s): 5af1c18

Update app.py

Files changed (1)
  1. app.py +39 -35
app.py CHANGED
@@ -299,6 +299,7 @@ class WanInferencePipeline(nn.Module):
     def get_times(self, prompt,
                   image_path=None,
                   audio_path=None,
+                  orientation_state = None,
                   seq_len=101, # not used while audio_path is not None
                   height=720,
                   width=720,
@@ -321,11 +322,11 @@ class WanInferencePipeline(nn.Module):
             image = self.transform(image).unsqueeze(0).to(dtype=self.dtype)
 
             _, _, h, w = image.shape
-            select_size = match_size(getattr( self.args, f'image_sizes_{ self.args.max_hw}'), h, w)
+            select_size = match_size(orientation_state, h, w)
             image = resize_pad(image, (h, w), select_size)
             image = image * 2.0 - 1.0
             image = image[:, :, None]
-
+
         else:
             image = None
             select_size = [height, width]
@@ -373,6 +374,7 @@ class WanInferencePipeline(nn.Module):
     def forward(self, prompt,
                 image_path=None,
                 audio_path=None,
+                orientation_state = None,
                 seq_len=101, # not used while audio_path is not None
                 height=720,
                 width=720,
@@ -394,17 +396,15 @@ class WanInferencePipeline(nn.Module):
             image = self.transform(image).unsqueeze(0).to(self.device, dtype=self.dtype)
 
             _, _, h, w = image.shape
-            select_size = match_size(getattr(self.args, f'image_sizes_{self.args.max_hw}'), h, w)
+            select_size = match_size(orientation_state, h, w)
             image = resize_pad(image, (h, w), select_size)
             image = image * 2.0 - 1.0
             image = image[:, :, None]
 
         else:
             image = None
-            select_size = [height, width]
-            # L = int(self.args.max_tokens * 16 * 16 * 4 / select_size[0] / select_size[1])
-            # L = L // 4 * 4 + 1 if L % 4 != 0 else L - 3 # video frames
-            # T = (L + 3) // 4 # latent frames
+            h = height
+            w = width
 
         # step 1: numerator and denominator as ints
         num = args.max_tokens * 16 * 16 * 4
@@ -414,9 +414,6 @@ class WanInferencePipeline(nn.Module):
         L0 = num // den # exact floor division, no float in sight
 
         # step 3: make it ≡ 1 mod 4
-        # if L0 % 4 == 1, keep L0;
-        # otherwise subtract the difference so that (L0 - diff) % 4 == 1,
-        # but ensure the result stays positive.
         diff = (L0 - 1) % 4
         L = L0 - diff
         if L < 1:
@@ -555,7 +552,7 @@ ADAPTIVE_PROMPT_TEMPLATES = [
     "A realistic video of a person speaking and sometimes looking directly to the camera and moving their eyes and pupils and head accordingly and turning and looking at the camera and looking away from the camera based on their movements with dynamic and rhythmic and extensive hand gestures that complement their speech. Their hands are clearly visible, independent, and unobstructed. Their facial expressions are expressive and full of emotion, enhancing the delivery. The camera remains steady, capturing sharp, clear movements and a focused, engaging presence."
 ]
 
-def slider_value_change(image_path, audio_path, text, num_steps, session_state, adaptive_text):
+def slider_value_change(image_path, audio_path, orientation_state, text, num_steps, session_state, adaptive_text):
 
     if adaptive_text:
 
@@ -569,20 +566,20 @@ def slider_value_change(image_path, audio_path, text, num_steps, session_state,
     else:
         text = ADAPTIVE_PROMPT_TEMPLATES[1]
 
-    return update_generate_button(image_path, audio_path, text, num_steps, session_state), text
+    return update_generate_button(image_path, audio_path, orientation_state, text, num_steps, session_state), text
 
 
-def update_generate_button(image_path, audio_path, text, num_steps, session_state):
+def update_generate_button(image_path, audio_path, orientation_state, text, num_steps, session_state):
 
     if image_path is None or audio_path is None:
         return gr.update(value="⌚ Zero GPU Required: --")
 
-    duration_s = get_duration(image_path, audio_path, text, num_steps, session_state, None)
+    duration_s = get_duration(image_path, audio_path, text, orientation_state, num_steps, session_state, None)
     duration_m = duration_s / 60
 
     return gr.update(value=f"⌚ Zero GPU Required: ~{duration_s}.0s ({duration_m:.1f} mins)")
 
-def get_duration(image_path, audio_path, text, num_steps, session_id, progress):
+def get_duration(image_path, audio_path, text, orientation_state, num_steps, session_id, progress):
 
     if image_path is None:
         gr.Info("Step1: Please Provide an Image or Choose from Image Samples")
@@ -601,6 +598,7 @@ def get_duration(image_path, audio_path, text, num_steps, session_id, progress):
         prompt=text,
         image_path=image_path,
         audio_path=audio_path,
+        orientation_state= orientation_state,
         seq_len=args.seq_len,
         num_steps=num_steps
     )
@@ -615,7 +613,7 @@ def get_duration(image_path, audio_path, text, num_steps, session_id, progress):
 
     return int(duration_s)
 
-def preprocess_img(input_image_path, raw_image_path, session_id = None):
+def preprocess_img(input_image_path, raw_image_path, orientation_state, session_id = None):
 
     if session_id is None:
         session_id = uuid.uuid4().hex
@@ -631,7 +629,7 @@ def preprocess_img(input_image_path, raw_image_path, session_id = None):
     image = inferpipe.transform(image).unsqueeze(0).to(dtype=inferpipe.dtype)
 
     _, _, h, w = image.shape
-    select_size = match_size(getattr( args, f'image_sizes_{ args.max_hw}'), h, w)
+    select_size = match_size(orientation_state, h, w)
     image = resize_pad(image, (h, w), select_size)
     image = image * 2.0 - 1.0
     image = image[:, :, None]
@@ -649,17 +647,16 @@ def preprocess_img(input_image_path, raw_image_path, session_id = None):
 
 def infer_example(image_path, audio_path, text, num_steps, raw_image_path, session_id = None, progress=gr.Progress(track_tqdm=True),):
 
-    current_image_size = args.image_sizes_720
-    args.image_sizes_720 = [[720, 400]]
-
-    result = infer(image_path, audio_path, text, num_steps, session_id, progress)
+    if session_id is None:
+        session_id = uuid.uuid4().hex
 
-    args.image_sizes_720 = current_image_size
+    image_path, _ = preprocess_img(image_path, image_path, [[720, 400]], session_id)
+    result = infer(image_path, audio_path, text, [[720, 400]], num_steps, session_id, progress)
 
     return result
 
 @spaces.GPU(duration=get_duration)
-def infer(image_path, audio_path, text, num_steps, session_id = None, progress=gr.Progress(track_tqdm=True),):
+def infer(image_path, audio_path, text, orientation_state, num_steps, session_id = None, progress=gr.Progress(track_tqdm=True),):
 
     if image_path is None:
 
@@ -694,6 +691,7 @@ def infer(image_path, audio_path, text, num_steps, session_id = None, progress=g
         prompt=text,
         image_path=image_path,
         audio_path=input_audio_path,
+        orientation_state=orientation_state,
         seq_len=args.seq_len,
         num_steps=num_steps
     )
@@ -713,7 +711,8 @@ def infer(image_path, audio_path, text, num_steps, session_id = None, progress=g
 
 def apply_image(request):
     print('image applied')
-    return request, None
+
+    return request, request
 
 def apply_audio(request):
     print('audio applied')
@@ -739,13 +738,15 @@ def orientation_changed(session_id, evt: gr.EventData):
     detail = getattr(evt, "data", None) or getattr(evt, "_data", {}) or {}
 
     if detail['value'] == "9:16":
-        args.image_sizes_720 = [[720, 400]]
+        orientation_state = [[720, 400]]
     elif detail['value'] == "1:1":
-        args.image_sizes_720 = [[720, 720]]
+        orientation_state = [[720, 720]]
    elif detail['value'] == "16:9":
-        args.image_sizes_720 = [[400, 720]]
+        orientation_state = [[400, 720]]
+
+    print(f'{session_id} has {orientation_state} orientation')
 
-    print(f'{session_id} has {args.image_sizes_720} orientation')
+    return orientation_state
 
 def clear_raw_image():
     return ''
@@ -819,6 +820,7 @@ css = """
 with gr.Blocks(css=css) as demo:
 
     session_state = gr.State()
+    orientation_state = gr.State([[720, 400]])
     demo.load(start_session, outputs=[session_state])
 
 
@@ -936,7 +938,9 @@ with gr.Blocks(css=css) as demo:
         ],
         label="Image Samples",
         inputs=[image_input],
-        cache_examples=False
+        outputs=[image_input, raw_img_text],
+        fn=apply_image,
+        cache_examples=True
     )
 
     audio_examples = gr.Examples(
@@ -964,7 +968,7 @@ with gr.Blocks(css=css) as demo:
 
     infer_btn.click(
         fn=infer,
-        inputs=[image_input, audio_input, text_input, num_steps, session_state],
+        inputs=[image_input, audio_input, text_input, orientation_state, num_steps, session_state],
         outputs=[output_video]
     )
 
@@ -981,12 +985,12 @@ with gr.Blocks(css=css) as demo:
         inputs=[audio_input, limit_on, session_state],
         outputs=[audio_input],
     )
-    image_input.orientation(fn=orientation_changed, inputs=[session_state]).then(fn=preprocess_img, inputs=[image_input, raw_img_text, session_state], outputs=[image_input, raw_img_text])
+    image_input.orientation(fn=orientation_changed, inputs=[session_state], outputs=[orientation_state]).then(fn=preprocess_img, inputs=[image_input, raw_img_text, orientation_state, session_state], outputs=[image_input, raw_img_text])
     image_input.clear(fn=clear_raw_image, outputs=[raw_img_text])
-    image_input.upload(fn=preprocess_img, inputs=[image_input, raw_img_text, session_state], outputs=[image_input, raw_img_text])
-    image_input.change(fn=update_generate_button, inputs=[image_input, audio_input, text_input, num_steps, session_state], outputs=[time_required])
-    audio_input.change(fn=update_generate_button, inputs=[image_input, audio_input, text_input, num_steps, session_state], outputs=[time_required])
-    num_steps.change(fn=slider_value_change, inputs=[image_input, audio_input, text_input, num_steps, session_state, adaptive_text], outputs=[time_required, text_input])
+    image_input.upload(fn=preprocess_img, inputs=[image_input, raw_img_text, orientation_state, session_state], outputs=[image_input, raw_img_text])
+    image_input.change(fn=update_generate_button, inputs=[image_input, audio_input, orientation_state, text_input, num_steps, session_state], outputs=[time_required])
+    audio_input.change(fn=update_generate_button, inputs=[image_input, audio_input, orientation_state, text_input, num_steps, session_state], outputs=[time_required])
+    num_steps.change(fn=slider_value_change, inputs=[image_input, audio_input, orientation_state, text_input, num_steps, session_state, adaptive_text], outputs=[time_required, text_input])
    adaptive_text.change(fn=check_box_clicked, inputs=[adaptive_text], outputs=[text_input])
    audio_input.upload(fn=apply_audio, inputs=[audio_input], outputs=[audio_input]
    ).then(
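
The frame-length computation these hunks touch is integer-only, and a worked sketch outside the diff may help when reviewing. The max_tokens value below is hypothetical (the real one comes from args.max_tokens), den is assumed to be select_size[0] * select_size[1] as in the comment this commit removes, and the latent-frame line restates the removed T = (L + 3) // 4 note.

# Worked sketch of the integer-only frame-length math in get_times()/forward().
# Assumptions: max_tokens is a made-up value; den follows the removed comment.
max_tokens = 30000
select_size = [720, 400]              # 9:16, the default orientation in this commit

# step 1: numerator and denominator as ints
num = max_tokens * 16 * 16 * 4
den = select_size[0] * select_size[1]

# step 2: exact floor division, no float in sight
L0 = num // den                       # 30_720_000 // 288_000 == 106

# step 3: make it congruent to 1 mod 4, but keep it positive
diff = (L0 - 1) % 4
L = max(L0 - diff, 1)                 # 105 video frames, 105 % 4 == 1
T = (L + 3) // 4                      # 27 latent frames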
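
The rest of the diff threads a per-session orientation value through the UI instead of mutating the module-level args.image_sizes_720, so concurrent Space users can no longer clobber each other's setting. A minimal sketch of that gr.State pattern follows, with stand-in handler bodies: only the component names mirror app.py, everything else is illustrative.

import gradio as gr

def orientation_changed(value):
    # Stand-in for the real handler: map an aspect ratio to a target size.
    sizes = {"9:16": [[720, 400]], "1:1": [[720, 720]], "16:9": [[400, 720]]}
    return sizes.get(value, [[720, 400]])

def infer(image_path, orientation_state):
    # Stand-in for the real infer(): just proves the per-session value arrives.
    return f"would render {image_path} at {orientation_state}"

with gr.Blocks() as demo:
    # One value per browser session; [[720, 400]] matches the commit's default.
    orientation_state = gr.State([[720, 400]])

    ratio = gr.Radio(["9:16", "1:1", "16:9"], value="9:16", label="Orientation")
    image_input = gr.Image(type="filepath")
    output = gr.Textbox()

    # Handlers read and write the State instead of a shared global args object.
    ratio.change(fn=orientation_changed, inputs=[ratio], outputs=[orientation_state])
    image_input.change(fn=infer, inputs=[image_input, orientation_state], outputs=[output])

demo.launch()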