Spaces: Running on Zero
Update hf_gradio_app.py
hf_gradio_app.py +17 -18
hf_gradio_app.py
CHANGED
@@ -65,24 +65,22 @@ from memo.utils.vision_utils import preprocess_image, tensor_to_video
 device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
 weight_dtype = torch.bfloat16
 
-with torch.inference_mode():
-    … (15 indented lines of the with-block body, the model loading and eval setup, are not shown in the rendered diff)
-    pipeline = VideoPipeline(vae=vae, reference_net=reference_net, diffusion_net=diffusion_net, scheduler=noise_scheduler, image_proj=image_proj)
-    pipeline.to(device=device, dtype=weight_dtype)
+#with torch.inference_mode():
+vae = AutoencoderKL.from_pretrained("./checkpoints/vae").to(device=device, dtype=weight_dtype)
+reference_net = UNet2DConditionModel.from_pretrained("./checkpoints", subfolder="reference_net", use_safetensors=True)
+diffusion_net = UNet3DConditionModel.from_pretrained("./checkpoints", subfolder="diffusion_net", use_safetensors=True)
+image_proj = ImageProjModel.from_pretrained("./checkpoints", subfolder="image_proj", use_safetensors=True)
+audio_proj = AudioProjModel.from_pretrained("./checkpoints", subfolder="audio_proj", use_safetensors=True)
+vae.requires_grad_(False).eval()
+reference_net.requires_grad_(False).eval()
+diffusion_net.requires_grad_(False).eval()
+image_proj.requires_grad_(False).eval()
+audio_proj.requires_grad_(False).eval()
+#reference_net.enable_xformers_memory_efficient_attention()
+#diffusion_net.enable_xformers_memory_efficient_attention()
+noise_scheduler = FlowMatchEulerDiscreteScheduler()
+pipeline = VideoPipeline(vae=vae, reference_net=reference_net, diffusion_net=diffusion_net, scheduler=noise_scheduler, image_proj=image_proj)
+#pipeline.to(device=device, dtype=weight_dtype)
 
 def process_audio(file_path, temp_dir):
     # Load the audio file
@@ -104,6 +102,7 @@ def process_audio(file_path, temp_dir):
 #@torch.inference_mode()
 @spaces.GPU(duration=200)
 def generate(input_video, input_audio, seed, progress=gr.Progress(track_tqdm=True)):
+    pipeline.to(device=device, dtype=weight_dtype)
 
     is_shared_ui = True if "fffiloni/MEMO" in os.environ['SPACE_ID'] else False
     temp_dir = None
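For context, this commit follows the usual ZeroGPU loading pattern: the model components are built at import time on the CPU (the torch.inference_mode() wrapper and the module-level pipeline.to(...) call are commented out), and the transfer to CUDA is deferred into generate(), the function decorated with @spaces.GPU(duration=200), which is the only place a GPU is guaranteed to be attached. Below is a minimal sketch of that pattern; toy_model and run are illustrative stand-ins, not the Space's actual code.

import torch
import spaces

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
weight_dtype = torch.bfloat16

# Module level: build the weights on the CPU; no CUDA work happens at import time.
toy_model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.GELU())
toy_model.requires_grad_(False).eval()

@spaces.GPU(duration=200)  # ZeroGPU attaches a GPU only while this function runs
def run(batch_size: int = 1):
    # The device/dtype move is deferred to here, mirroring the commit's relocation of
    # pipeline.to(device=device, dtype=weight_dtype) into generate().
    toy_model.to(device=device, dtype=weight_dtype)
    with torch.inference_mode():
        x = torch.randn(batch_size, 16, device=device, dtype=weight_dtype)
        return toy_model(x).float().cpu()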