fffiloni commited on
Commit
7f9b687
·
verified ·
1 Parent(s): 13edd1a

Update gradio_app.py

Browse files
Files changed (1) hide show
  1. gradio_app.py +6 -22
gradio_app.py CHANGED
@@ -10,7 +10,7 @@ from attn_ctrl.attention_control import (AttentionStore,
10
  register_temporal_self_attention_control,
11
  register_temporal_self_attention_flip_control,
12
  )
13
- from torch.cuda.amp import autocast
14
  import gc
15
 
16
  # Set PYTORCH_CUDA_ALLOC_CONF
@@ -32,7 +32,7 @@ pipe = FrameInterpolationWithNoiseInjectionPipeline.from_pretrained(
32
  scheduler=noise_scheduler,
33
  variant="fp16",
34
  torch_dtype=torch.float16,
35
- )
36
  ref_unet = pipe.ori_unet
37
 
38
  # Compute delta w
@@ -41,14 +41,14 @@ finetuned_unet = UNetSpatioTemporalConditionModel.from_pretrained(
41
  checkpoint_dir,
42
  subfolder="unet",
43
  torch_dtype=torch.float16,
44
- )
45
  assert finetuned_unet.config.num_frames == 14
46
  ori_unet = UNetSpatioTemporalConditionModel.from_pretrained(
47
  "stabilityai/stable-video-diffusion-img2vid",
48
  subfolder="unet",
49
  variant='fp16',
50
  torch_dtype=torch.float16,
51
- )
52
 
53
  finetuned_state_dict = finetuned_unet.state_dict()
54
  ori_state_dict = ori_unet.state_dict()
@@ -105,15 +105,8 @@ def infer(frame1_path, frame2_path):
105
  # Clear CUDA cache
106
  cuda_memory_cleanup()
107
 
108
- # Move model to CPU and clear CUDA cache
109
- pipe.to("cpu")
110
- cuda_memory_cleanup()
111
-
112
- # Move model back to GPU
113
- pipe.to(device)
114
-
115
  try:
116
- with autocast(device_type='cuda', dtype=torch.float16):
117
  frames = pipe(
118
  image1=frame1,
119
  image2=frame2,
@@ -138,15 +131,8 @@ def infer(frame1_path, frame2_path):
138
  else:
139
  return f"An error occurred: {str(e)}"
140
  finally:
141
- # Move model back to CPU and clear CUDA cache
142
- pipe.to("cpu")
143
  cuda_memory_cleanup()
144
 
145
- @torch.no_grad()
146
- def load_model():
147
- global pipe
148
- pipe = pipe.to(device)
149
-
150
  with gr.Blocks() as demo:
151
  with gr.Column():
152
  gr.Markdown("# Keyframe Interpolation with Stable Video Diffusion")
@@ -165,6 +151,4 @@ with gr.Blocks() as demo:
165
  show_api=False
166
  )
167
 
168
- demo.load(load_model)
169
-
170
- demo.queue(max_size=1).launch(show_api=False, show_error=True)
 
10
  register_temporal_self_attention_control,
11
  register_temporal_self_attention_flip_control,
12
  )
13
+ from torch.amp import autocast
14
  import gc
15
 
16
  # Set PYTORCH_CUDA_ALLOC_CONF
 
32
  scheduler=noise_scheduler,
33
  variant="fp16",
34
  torch_dtype=torch.float16,
35
+ ).to(device)
36
  ref_unet = pipe.ori_unet
37
 
38
  # Compute delta w
 
41
  checkpoint_dir,
42
  subfolder="unet",
43
  torch_dtype=torch.float16,
44
+ ).to(device)
45
  assert finetuned_unet.config.num_frames == 14
46
  ori_unet = UNetSpatioTemporalConditionModel.from_pretrained(
47
  "stabilityai/stable-video-diffusion-img2vid",
48
  subfolder="unet",
49
  variant='fp16',
50
  torch_dtype=torch.float16,
51
+ ).to(device)
52
 
53
  finetuned_state_dict = finetuned_unet.state_dict()
54
  ori_state_dict = ori_unet.state_dict()
 
105
  # Clear CUDA cache
106
  cuda_memory_cleanup()
107
 
 
 
 
 
 
 
 
108
  try:
109
+ with autocast():
110
  frames = pipe(
111
  image1=frame1,
112
  image2=frame2,
 
131
  else:
132
  return f"An error occurred: {str(e)}"
133
  finally:
 
 
134
  cuda_memory_cleanup()
135
 
 
 
 
 
 
136
  with gr.Blocks() as demo:
137
  with gr.Column():
138
  gr.Markdown("# Keyframe Interpolation with Stable Video Diffusion")
 
151
  show_api=False
152
  )
153
 
154
+ demo.queue(max_size=1).launch(show_api=False, show_error=True, share=True)