alexnasa committed (verified)
Commit 5af1c18 · Parent(s): 6b2fdf5

Update app.py

Files changed (1): app.py (+27 -23)
app.py CHANGED
@@ -321,6 +321,8 @@ class WanInferencePipeline(nn.Module):
             image = self.transform(image).unsqueeze(0).to(dtype=self.dtype)
 
             _, _, h, w = image.shape
+            select_size = match_size(getattr(self.args, f'image_sizes_{self.args.max_hw}'), h, w)
+            image = resize_pad(image, (h, w), select_size)
             image = image * 2.0 - 1.0
             image = image[:, :, None]
 
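Note: match_size and resize_pad are helpers defined elsewhere in this repo; the commit only changes where their size candidates come from (the per-orientation buckets stored on args). As a rough sketch of the assumed behaviour, match_size presumably picks the [H, W] bucket whose aspect ratio is closest to the input's (hypothetical reimplementation, not the repo's code):

def match_size_sketch(candidates, h, w):
    # candidates: list of [H, W] buckets, e.g. [[720, 400]]
    target = h / w
    return min(candidates, key=lambda hw: abs(hw[0] / hw[1] - target))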
@@ -328,7 +330,7 @@ class WanInferencePipeline(nn.Module):
             image = None
             select_size = [height, width]
         num = self.args.max_tokens * 16 * 16 * 4
-        den = h * w
+        den = select_size[0] * select_size[1]
         L0 = num // den
         diff = (L0 - 1) % 4
         L = L0 - diff
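Note: switching the denominator to select_size fixes a real bug in this branch: when image is None, h and w were never assigned here, and even with an image they held the pre-padding dimensions rather than the [height, width] the video is actually generated at.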
@@ -392,22 +394,29 @@ class WanInferencePipeline(nn.Module):
             image = self.transform(image).unsqueeze(0).to(self.device, dtype=self.dtype)
 
             _, _, h, w = image.shape
+            select_size = match_size(getattr(self.args, f'image_sizes_{self.args.max_hw}'), h, w)
+            image = resize_pad(image, (h, w), select_size)
             image = image * 2.0 - 1.0
             image = image[:, :, None]
 
         else:
             image = None
-            h = height
-            w = width
+            select_size = [height, width]
+            # L = int(self.args.max_tokens * 16 * 16 * 4 / select_size[0] / select_size[1])
+            # L = L // 4 * 4 + 1 if L % 4 != 0 else L - 3  # video frames
+            # T = (L + 3) // 4  # latent frames
 
         # step 1: numerator and denominator as ints
         num = args.max_tokens * 16 * 16 * 4
-        den = h * w
+        den = select_size[0] * select_size[1]
 
         # step 2: integer division
         L0 = num // den  # exact floor division, no float in sight
 
         # step 3: make it ≡ 1 mod 4
+        # if L0 % 4 == 1, keep L0;
+        # otherwise subtract the difference so that (L0 - diff) % 4 == 1,
+        # but ensure the result stays positive.
         diff = (L0 - 1) % 4
         L = L0 - diff
         if L < 1:
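Note: a worked instance of steps 1-3, using an illustrative max_tokens (the real value comes from args) and the 9:16 bucket:

max_tokens = 32760                     # illustrative only; args.max_tokens in the app
select_size = [720, 400]               # the 9:16 bucket
num = max_tokens * 16 * 16 * 4         # 33,546,240
den = select_size[0] * select_size[1]  # 288,000
L0 = num // den                        # 116
diff = (L0 - 1) % 4                    # (116 - 1) % 4 == 3
L = L0 - diff                          # 113, and 113 % 4 == 1 as required
assert L % 4 == 1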
@@ -606,7 +615,7 @@ def get_duration(image_path, audio_path, text, num_steps, session_id, progress):
 
     return int(duration_s)
 
-def preprocess_img(input_image_path, raw_image_path, orientation_state, session_id = None):
+def preprocess_img(input_image_path, raw_image_path, session_id = None):
 
     if session_id is None:
         session_id = uuid.uuid4().hex
@@ -622,7 +631,7 @@ def preprocess_img(input_image_path, raw_image_path, orientation_state, session_id = None):
     image = inferpipe.transform(image).unsqueeze(0).to(dtype=inferpipe.dtype)
 
     _, _, h, w = image.shape
-    select_size = match_size(orientation_state, h, w)
+    select_size = match_size(getattr(args, f'image_sizes_{args.max_hw}'), h, w)
     image = resize_pad(image, (h, w), select_size)
     image = image * 2.0 - 1.0
     image = image[:, :, None]
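Note: preprocess_img now reads the target size from the module-level args rather than a per-session orientation_state, which appears to mean that concurrent sessions share a single orientation setting; session_id is kept only as a parameter.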
@@ -640,12 +649,13 @@ def preprocess_img(input_image_path, raw_image_path, orientation_state, session_id = None):
 
 def infer_example(image_path, audio_path, text, num_steps, raw_image_path, session_id = None, progress=gr.Progress(track_tqdm=True),):
 
-    if session_id is None:
-        session_id = uuid.uuid4().hex
+    current_image_size = args.image_sizes_720
+    args.image_sizes_720 = [[720, 400]]
 
-    image_path, _ = preprocess_img(image_path, image_path, [[720, 400]], session_id)
     result = infer(image_path, audio_path, text, num_steps, session_id, progress)
 
+    args.image_sizes_720 = current_image_size
+
     return result
 
 @spaces.GPU(duration=get_duration)
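Note: infer_example saves and restores args.image_sizes_720 around the call, but an exception inside infer would skip the restore and leave the global patched. A try/finally variant would be more defensive (sketch only, reusing app.py's names):

def infer_example(image_path, audio_path, text, num_steps, raw_image_path,
                  session_id=None, progress=gr.Progress(track_tqdm=True)):
    current_image_size = args.image_sizes_720
    args.image_sizes_720 = [[720, 400]]   # force the 9:16 example bucket
    try:
        return infer(image_path, audio_path, text, num_steps, session_id, progress)
    finally:
        args.image_sizes_720 = current_image_size  # restore even if infer() raises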
@@ -703,8 +713,7 @@ def infer(image_path, audio_path, text, num_steps, session_id = None, progress=gr.Progress(track_tqdm=True)):
 
 def apply_image(request):
     print('image applied')
-
-    return request, request
+    return request, None
 
 def apply_audio(request):
     print('audio applied')
@@ -730,15 +739,13 @@ def orientation_changed(session_id, evt: gr.EventData):
     detail = getattr(evt, "data", None) or getattr(evt, "_data", {}) or {}
 
     if detail['value'] == "9:16":
-        orientation_state = [[720, 400]]
+        args.image_sizes_720 = [[720, 400]]
     elif detail['value'] == "1:1":
-        orientation_state = [[720, 720]]
+        args.image_sizes_720 = [[720, 720]]
     elif detail['value'] == "16:9":
-        orientation_state = [[400, 720]]
-
-    print(f'{session_id} has {orientation_state} orientation')
+        args.image_sizes_720 = [[400, 720]]
 
-    return orientation_state
+    print(f'{session_id} has {args.image_sizes_720} orientation')
 
 def clear_raw_image():
     return ''
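Note: the if/elif chain maps an aspect-ratio label to one [H, W] bucket on args; a table-driven equivalent (sketch, preserving the behaviour of leaving args untouched for unknown labels):

ORIENTATION_SIZES = {
    "9:16": [[720, 400]],
    "1:1":  [[720, 720]],
    "16:9": [[400, 720]],
}

def orientation_changed(session_id, evt: gr.EventData):
    detail = getattr(evt, "data", None) or getattr(evt, "_data", {}) or {}
    value = detail.get('value')
    if value in ORIENTATION_SIZES:
        args.image_sizes_720 = ORIENTATION_SIZES[value]
    print(f'{session_id} has {args.image_sizes_720} orientation')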
@@ -812,7 +819,6 @@ css = """
 with gr.Blocks(css=css) as demo:
 
     session_state = gr.State()
-    orientation_state = gr.State([[720, 400]])
     demo.load(start_session, outputs=[session_state])
 
 
@@ -930,9 +936,7 @@ with gr.Blocks(css=css) as demo:
         ],
         label="Image Samples",
         inputs=[image_input],
-        outputs=[image_input, raw_img_text],
-        fn=apply_image,
-        cache_examples=True
+        cache_examples=False
     )
 
     audio_examples = gr.Examples(
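Note: with fn=apply_image and its outputs dropped, clicking a sample now just fills image_input directly, and cache_examples=False stops Gradio from pre-running the examples at build time.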
@@ -977,9 +981,9 @@ with gr.Blocks(css=css) as demo:
         inputs=[audio_input, limit_on, session_state],
         outputs=[audio_input],
     )
-    image_input.orientation(fn=orientation_changed, inputs=[session_state], outputs=[orientation_state]).then(fn=preprocess_img, inputs=[image_input, raw_img_text, orientation_state, session_state], outputs=[image_input, raw_img_text])
+    image_input.orientation(fn=orientation_changed, inputs=[session_state]).then(fn=preprocess_img, inputs=[image_input, raw_img_text, session_state], outputs=[image_input, raw_img_text])
     image_input.clear(fn=clear_raw_image, outputs=[raw_img_text])
-    image_input.upload(fn=preprocess_img, inputs=[image_input, raw_img_text, orientation_state, session_state], outputs=[image_input, raw_img_text])
+    image_input.upload(fn=preprocess_img, inputs=[image_input, raw_img_text, session_state], outputs=[image_input, raw_img_text])
     image_input.change(fn=update_generate_button, inputs=[image_input, audio_input, text_input, num_steps, session_state], outputs=[time_required])
     audio_input.change(fn=update_generate_button, inputs=[image_input, audio_input, text_input, num_steps, session_state], outputs=[time_required])
     num_steps.change(fn=slider_value_change, inputs=[image_input, audio_input, text_input, num_steps, session_state, adaptive_text], outputs=[time_required, text_input])
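Note: the .orientation and .upload wirings drop orientation_state from both inputs and outputs; orientation now flows solely through the mutated args.image_sizes_720 global that preprocess_img reads.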