multimodalart (HF Staff) committed
Commit 6243da9 · verified · Parent: 71c64ba

get rid of gr.State
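
The change replaces the per-button gr.State("text-to-video") / gr.State("image-to-video") / gr.State("video-to-video") inputs with a single hidden mode dropdown that tab-select events keep up to date, passes frames_to_use in place of the old gr.State(0) placeholder, and drops the leftover steps_input comments from the advanced-settings block.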

Files changed (1)
  1. app.py +27 -12
app.py CHANGED
@@ -143,7 +143,7 @@ def get_duration(prompt, negative_prompt, input_image_filepath, input_video_file
 @spaces.GPU(duration=get_duration)
 def generate(prompt, negative_prompt, input_image_filepath, input_video_filepath,
              height_ui, width_ui, mode,
-             duration_ui, # Removed ui_steps
+             duration_ui,
              ui_frames_to_use,
              seed_ui, randomize_seed, ui_guidance_scale, improve_texture_flag,
              progress=gr.Progress(track_tqdm=True)):
@@ -321,6 +321,15 @@ def generate(prompt, negative_prompt, input_image_filepath, input_video_filepath
 
     return output_video_path, seed_ui
 
+def update_task_image():
+    return "image-to-video"
+
+def update_task_text():
+    return "text-to-video"
+
+def update_task_video():
+    return "video-to-video"
+
 
 # --- Gradio UI Definition ---
 css="""
@@ -368,15 +377,13 @@ with gr.Blocks(css=css) as demo:
             gr.DeepLinkButton()
 
     with gr.Accordion("Advanced settings", open=False):
+        mode = gr.Dropdown(["text-to-video", "image-to-video", "video-to-video"], label="task", visible=False)
         negative_prompt_input = gr.Textbox(label="Negative Prompt", value="worst quality, inconsistent motion, blurry, jittery, distorted", lines=2)
         with gr.Row():
             seed_input = gr.Number(label="Seed", value=42, precision=0, minimum=0, maximum=2**32-1)
             randomize_seed_input = gr.Checkbox(label="Randomize Seed", value=True)
         with gr.Row():
             guidance_scale_input = gr.Slider(label="Guidance Scale (CFG)", minimum=1.0, maximum=10.0, value=PIPELINE_CONFIG_YAML.get("first_pass", {}).get("guidance_scale", 1.0), step=0.1, info="Controls how much the prompt influences the output. Higher values = stronger influence.")
-        # Removed steps_input slider
-        # default_steps = len(PIPELINE_CONFIG_YAML.get("first_pass", {}).get("timesteps", [1]*7))
-        # steps_input = gr.Slider(label="Inference Steps (for first pass if multi-scale)", minimum=1, maximum=30, value=default_steps, step=1, info="Number of denoising steps. More steps can improve quality but increase time. If YAML defines 'timesteps' for a pass, this UI value is ignored for that pass.")
         with gr.Row():
             height_input = gr.Slider(label="Height", value=512, step=32, minimum=MIN_DIM_SLIDER, maximum=MAX_IMAGE_SIZE, info="Must be divisible by 32.")
             width_input = gr.Slider(label="Width", value=704, step=32, minimum=MIN_DIM_SLIDER, maximum=MAX_IMAGE_SIZE, info="Must be divisible by 32.")
@@ -433,7 +440,7 @@ with gr.Blocks(css=css) as demo:
             print(f"Error processing video for dimension update: {e} (Path: {video_filepath}, Type: {type(video_filepath)})")
             return gr.update(value=current_h), gr.update(value=current_w)
 
-    # Attach upload handlers
+
     image_i2v.upload(
         fn=handle_image_upload_for_dims,
         inputs=[image_i2v, height_input, width_input],
@@ -444,21 +451,29 @@ with gr.Blocks(css=css) as demo:
         inputs=[video_v2v, height_input, width_input],
         outputs=[height_input, width_input]
     )
+
+    image_tab.select(
+        fn=update_task_image,
+        outputs=[mode]
+    )
+    text_tab.select(
+        fn=update_task_text,
+        outputs=[mode]
+    )
 
-    # --- INPUT LISTS (remain the same structurally) ---
     t2v_inputs = [t2v_prompt, negative_prompt_input, image_n_hidden, video_n_hidden,
-                  height_input, width_input, gr.State("text-to-video"),
-                  duration_input, gr.State(0), # Removed steps_input
+                  height_input, width_input, mode,
+                  duration_input, frames_to_use,
                   seed_input, randomize_seed_input, guidance_scale_input, improve_texture]
 
     i2v_inputs = [i2v_prompt, negative_prompt_input, image_i2v, video_i_hidden,
-                  height_input, width_input, gr.State("image-to-video"),
-                  duration_input, gr.State(0), # Removed steps_input
+                  height_input, width_input, mode,
+                  duration_input, frames_to_use,
                   seed_input, randomize_seed_input, guidance_scale_input, improve_texture]
 
     v2v_inputs = [v2v_prompt, negative_prompt_input, image_v_hidden, video_v2v,
-                  height_input, width_input, gr.State("video-to-video"),
-                  duration_input, frames_to_use, # Removed steps_input
+                  height_input, width_input, mode,
+                  duration_input, frames_to_use,
                   seed_input, randomize_seed_input, guidance_scale_input, improve_texture]
 
     t2v_button.click(fn=generate, inputs=t2v_inputs, outputs=[output_video, seed_input], api_name="text_to_video")
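
For reference, below is a minimal, self-contained sketch of the wiring this commit moves to: a hidden gr.Dropdown tracks the active task, Tab.select events update it, and it is passed to the handler as an ordinary input instead of a per-button gr.State constant. Only mode and the update_task_* helpers come from the diff above; the other component names (text_tab, t2v_prompt, the stub generate, etc.) are illustrative stand-ins, not the Space's actual code.

# Minimal sketch (not the Space's actual code) of the pattern adopted here:
# a hidden Dropdown holds the task, tab .select() events keep it in sync,
# and it is passed to the handler as a normal input instead of gr.State.
import gradio as gr

def update_task_text():
    return "text-to-video"

def update_task_image():
    return "image-to-video"

def generate(prompt, mode):
    # Stand-in for the real generate(); just reports which task was requested.
    return f"task={mode}, prompt={prompt!r}"

with gr.Blocks() as demo:
    # Hidden dropdown that tracks the active tab's task.
    mode = gr.Dropdown(["text-to-video", "image-to-video"],
                       value="text-to-video", label="task", visible=False)
    with gr.Tab("text-to-video") as text_tab:
        t2v_prompt = gr.Textbox(label="Prompt")
        t2v_button = gr.Button("Generate")
    with gr.Tab("image-to-video") as image_tab:
        i2v_prompt = gr.Textbox(label="Prompt")
        i2v_button = gr.Button("Generate")
    result = gr.Textbox(label="Result")

    # Selecting a tab writes the matching task string into the hidden dropdown.
    text_tab.select(fn=update_task_text, outputs=[mode])
    image_tab.select(fn=update_task_image, outputs=[mode])

    # The dropdown is wired as an ordinary input, replacing gr.State("<task>").
    t2v_button.click(fn=generate, inputs=[t2v_prompt, mode], outputs=[result])
    i2v_button.click(fn=generate, inputs=[i2v_prompt, mode], outputs=[result])

if __name__ == "__main__":
    demo.launch()

Unlike a gr.State constant baked into each input list, the hidden Dropdown is a single component whose value event handlers can change, which is what allows the per-tab .select() updates shown in the diff.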