import os
import gradio as gr
import torch
from huggingface_hub import snapshot_download
from diffusers.utils import load_image, export_to_gif
from diffusers import UNetSpatioTemporalConditionModel
from custom_diffusers.pipelines.pipeline_frame_interpolation_with_noise_injection import FrameInterpolationWithNoiseInjectionPipeline
from custom_diffusers.schedulers.scheduling_euler_discrete import EulerDiscreteScheduler
from attn_ctrl.attention_control import (
    AttentionStore,
    register_temporal_self_attention_control,
    register_temporal_self_attention_flip_control,
)
from torch.amp import autocast  # torch.amp (not torch.cuda.amp) accepts the device_type kwarg used below
import shutil
import gc
# Let the CUDA caching allocator use expandable segments to reduce fragmentation-related OOMs
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
# Set up device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Download checkpoint
snapshot_download(repo_id="fffiloni/svd_keyframe_interpolation", local_dir="checkpoints")
checkpoint_dir = "checkpoints/svd_reverse_motion_with_attnflip"
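# The snapshot provides a UNet finetuned for reverse motion with attention
# flipping; only its weight deltas are merged into the base model below.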
# Initialize pipeline
pretrained_model_name_or_path = "stabilityai/stable-video-diffusion-img2vid-xt"
noise_scheduler = EulerDiscreteScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler")
pipe = FrameInterpolationWithNoiseInjectionPipeline.from_pretrained(
    pretrained_model_name_or_path,
    scheduler=noise_scheduler,
    variant="fp16",
    torch_dtype=torch.float16,
)
ref_unet = pipe.ori_unet
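# The custom pipeline keeps an untouched copy of the base UNet (`ori_unet`);
# it serves as the reference for the attention-flip control registered below.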
# Compute delta w
state_dict = pipe.unet.state_dict()
finetuned_unet = UNetSpatioTemporalConditionModel.from_pretrained(
    checkpoint_dir,
    subfolder="unet",
    torch_dtype=torch.float16,
)
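# The finetuned checkpoint follows the 14-frame SVD setting (img2vid-xt defaults
# to 25 frames), hence the reference weights come from img2vid, not img2vid-xt.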
assert finetuned_unet.config.num_frames == 14
ori_unet = UNetSpatioTemporalConditionModel.from_pretrained(
    "stabilityai/stable-video-diffusion-img2vid",
    subfolder="unet",
    variant='fp16',
    torch_dtype=torch.float16,
)
finetuned_state_dict = finetuned_unet.state_dict()
ori_state_dict = ori_unet.state_dict()
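# Graft only the temporal self-attention value (to_v) and output projection
# (to_out.0) deltas from the finetuned UNet onto the img2vid-xt weights.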
for name, param in finetuned_state_dict.items():
    if 'temporal_transformer_blocks.0.attn1.to_v' in name or "temporal_transformer_blocks.0.attn1.to_out.0" in name:
        delta_w = param - ori_state_dict[name]
        state_dict[name] = state_dict[name] + delta_w
pipe.unet.load_state_dict(state_dict)
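# Wire up attention control: the reference UNet records its temporal
# self-attention maps, and the main UNet replays them flipped in time.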
controller_ref = AttentionStore()
register_temporal_self_attention_control(ref_unet, controller_ref)
controller = AttentionStore()
register_temporal_self_attention_flip_control(pipe.unet, controller, controller_ref)
# Custom CUDA memory management function
def cuda_memory_cleanup():
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()
    gc.collect()
def check_outputs_folder(folder_path):
    if os.path.exists(folder_path) and os.path.isdir(folder_path):
        for filename in os.listdir(folder_path):
            file_path = os.path.join(folder_path, filename)
            try:
                if os.path.isfile(file_path) or os.path.islink(file_path):
                    os.unlink(file_path)
                elif os.path.isdir(file_path):
                    shutil.rmtree(file_path)
            except Exception as e:
                print(f'Failed to delete {file_path}. Reason: {e}')
    else:
        print(f'The folder {folder_path} does not exist.')
@torch.no_grad()
def infer(frame1_path, frame2_path):
    seed = 42
    num_inference_steps = 5  # Reduced from 10
    noise_injection_steps = 0
    noise_injection_ratio = 0.5
    weighted_average = False

    generator = torch.Generator(device)
    if seed is not None:
        generator = generator.manual_seed(seed)

    frame1 = load_image(frame1_path)
    frame1 = frame1.resize((256, 144))  # Reduced from (512, 288)
    frame2 = load_image(frame2_path)
    frame2 = frame2.resize((256, 144))  # Reduced from (512, 288)

    # Clear CUDA cache
    cuda_memory_cleanup()

    # Move model to CPU and clear CUDA cache
    pipe.to("cpu")
    cuda_memory_cleanup()

    # Move model back to GPU
    pipe.to(device)
    try:
        with autocast(device_type='cuda', dtype=torch.float16):
            frames = pipe(
                image1=frame1,
                image2=frame2,
                num_inference_steps=num_inference_steps,
                generator=generator,
                weighted_average=weighted_average,
                noise_injection_steps=noise_injection_steps,
                noise_injection_ratio=noise_injection_ratio,
            ).frames[0]

        # The pipeline typically returns PIL images; only move tensors to CPU
        frames = [f.cpu() if torch.is_tensor(f) else f for f in frames]

        out_dir = "result"
        check_outputs_folder(out_dir)
        os.makedirs(out_dir, exist_ok=True)
        out_path = "result/video_result.gif"
        export_to_gif(frames, out_path)
        return out_path
    except RuntimeError as e:
        if "CUDA out of memory" in str(e):
            return "Error: CUDA out of memory. Try reducing the image size or using fewer inference steps."
        else:
            return f"An error occurred: {str(e)}"
    finally:
        # Move model back to CPU and clear CUDA cache
        pipe.to("cpu")
        cuda_memory_cleanup()
@torch.no_grad()
def load_model():
    global pipe
    pipe = pipe.to(device)
with gr.Blocks() as demo:
    with gr.Column():
        gr.Markdown("# Keyframe Interpolation with Stable Video Diffusion")
        with gr.Row():
            with gr.Column():
                image_input1 = gr.Image(type="filepath")
                image_input2 = gr.Image(type="filepath")
                submit_btn = gr.Button("Submit")
            with gr.Column():
                output = gr.Textbox()

    submit_btn.click(
        fn=infer,
        inputs=[image_input1, image_input2],
        outputs=[output],
        show_api=False
    )
    demo.load(load_model)

demo.queue(max_size=1).launch(show_api=False, show_error=True)