Spaces:
Runtime error
Runtime error
add control processing
Browse files
app.py
CHANGED
|
@@ -9,7 +9,7 @@ from pipeline_ltx_condition_control import LTXConditionPipeline
|
|
| 9 |
from diffusers.utils import export_to_video, load_video
|
| 10 |
from torchvision import transforms
|
| 11 |
import random
|
| 12 |
-
|
| 13 |
|
| 14 |
dtype = torch.bfloat16
|
| 15 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
@@ -20,6 +20,7 @@ pipeline.to(device)
|
|
| 20 |
pipe_upsample.to(device)
|
| 21 |
pipeline.vae.enable_tiling()
|
| 22 |
|
|
|
|
| 23 |
|
| 24 |
CONTROL_LORAS = {
|
| 25 |
"canny": {
|
|
@@ -39,11 +40,11 @@ CONTROL_LORAS = {
|
|
| 39 |
}
|
| 40 |
}
|
| 41 |
@spaces.GPU()
|
| 42 |
-
def read_video(
|
| 43 |
"""
|
| 44 |
Reads a video file and converts it into a torch.Tensor with the shape [F, C, H, W].
|
| 45 |
"""
|
| 46 |
-
|
| 47 |
to_tensor_transform = transforms.ToTensor()
|
| 48 |
video_tensor = torch.stack([to_tensor_transform(img) for img in pil_images])
|
| 49 |
return video_tensor
|
|
@@ -89,17 +90,19 @@ def load_control_lora(control_type, current_lora_state):
|
|
| 89 |
print(f"Error loading {control_type} LoRA: {e}")
|
| 90 |
raise
|
| 91 |
|
| 92 |
-
def process_video_for_canny(
|
| 93 |
"""
|
| 94 |
Process video for canny control.
|
| 95 |
-
Placeholder function - will return video as-is for now.
|
| 96 |
-
TODO: Implement canny edge detection processing
|
| 97 |
"""
|
| 98 |
print("Processing video for canny control...")
|
| 99 |
-
|
| 100 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 101 |
|
| 102 |
-
def process_video_for_depth(
|
| 103 |
"""
|
| 104 |
Process video for depth control.
|
| 105 |
Placeholder function - will return video as-is for now.
|
|
@@ -109,7 +112,7 @@ def process_video_for_depth(video_tensor):
|
|
| 109 |
|
| 110 |
return video_tensor
|
| 111 |
|
| 112 |
-
def process_video_for_pose(
|
| 113 |
"""
|
| 114 |
Process video for pose control.
|
| 115 |
Placeholder function - will return video as-is for now.
|
|
@@ -119,16 +122,16 @@ def process_video_for_pose(video_tensor):
|
|
| 119 |
|
| 120 |
return video_tensor
|
| 121 |
|
| 122 |
-
def process_video_for_control(
|
| 123 |
"""Process video based on the selected control type"""
|
| 124 |
if control_type == "canny":
|
| 125 |
-
return process_video_for_canny(
|
| 126 |
elif control_type == "depth":
|
| 127 |
-
return process_video_for_depth(
|
| 128 |
elif control_type == "pose":
|
| 129 |
-
return process_video_for_pose(
|
| 130 |
else:
|
| 131 |
-
return
|
| 132 |
|
| 133 |
@spaces.GPU(duration=120)
|
| 134 |
def generate_video(
|
|
@@ -169,15 +172,13 @@ def generate_video(
|
|
| 169 |
# Load the appropriate control LoRA and update state
|
| 170 |
updated_lora_state = load_control_lora(control_type, current_lora_state)
|
| 171 |
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
video = read_video(reference_video)
|
| 176 |
-
|
| 177 |
-
progress(0.15, desc="Processing video for control...")
|
| 178 |
|
| 179 |
# Process video based on control type
|
| 180 |
processed_video = process_video_for_control(video, control_type)
|
|
|
|
| 181 |
|
| 182 |
progress(0.2, desc="Preparing generation parameters...")
|
| 183 |
|
|
|
|
| 9 |
from diffusers.utils import export_to_video, load_video
|
| 10 |
from torchvision import transforms
|
| 11 |
import random
|
| 12 |
+
from controlnet_aux import CannyDetector
|
| 13 |
|
| 14 |
dtype = torch.bfloat16
|
| 15 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
|
| 20 |
pipe_upsample.to(device)
|
| 21 |
pipeline.vae.enable_tiling()
|
| 22 |
|
| 23 |
+
canny_processor = CannyDetector()
|
| 24 |
|
| 25 |
CONTROL_LORAS = {
|
| 26 |
"canny": {
|
|
|
|
| 40 |
}
|
| 41 |
}
|
| 42 |
@spaces.GPU()
|
| 43 |
+
def read_video(video) -> torch.Tensor:
    """
    Reads a video file and converts it into a torch.Tensor with the shape [F, C, H, W].

    Args:
        video: an iterable of PIL images (e.g. the list returned by
            diffusers' ``load_video``), one per frame.

    Returns:
        torch.Tensor of shape [F, C, H, W] with float values in [0, 1]
        (ToTensor scales uint8 frames by 1/255).
    """
    to_tensor_transform = transforms.ToTensor()
    # BUG FIX: the body previously iterated an undefined name `pil_images`
    # (left over from a rename) which raised NameError at runtime; iterate
    # the `video` parameter instead.
    video_tensor = torch.stack([to_tensor_transform(img) for img in video])
    return video_tensor
|
|
|
|
| 90 |
print(f"Error loading {control_type} LoRA: {e}")
|
| 91 |
raise
|
| 92 |
|
| 93 |
+
def process_video_for_canny(video):
    """
    Process video for canny control.

    Runs Canny edge detection on every frame using the module-level
    ``canny_processor`` (a controlnet_aux ``CannyDetector``).

    Args:
        video: iterable of PIL images, one per frame.

    Returns:
        list of edge-map PIL images, one per input frame.
    """
    print("Processing video for canny control...")
    canny_video = []
    for frame in video:
        # TODO: change resolution logic
        # BUG FIX: this previously called an undefined name `processor`;
        # the detector instantiated at module level is `canny_processor`.
        canny_video.append(canny_processor(frame, low_threshold=50, high_threshold=200, detect_resolution=1024, image_resolution=1024))

    return canny_video
|
| 104 |
|
| 105 |
+
def process_video_for_depth(video):
    """
    Process video for depth control.
    Placeholder function - will return video as-is for now.

    Args:
        video: iterable of PIL images, one per frame.

    Returns:
        The input video unchanged (pass-through until depth estimation
        is implemented).
    """
    # TODO: implement depth-map extraction; pass-through for now.
    # BUG FIX: the parameter was renamed video_tensor -> video but the
    # body still returned the old name `video_tensor`, raising NameError.
    return video
|
| 114 |
|
| 115 |
+
def process_video_for_pose(video):
    """
    Process video for pose control.
    Placeholder function - will return video as-is for now.

    Args:
        video: iterable of PIL images, one per frame.

    Returns:
        The input video unchanged (pass-through until pose detection
        is implemented).
    """
    # TODO: implement pose extraction; pass-through for now.
    # BUG FIX: the parameter was renamed video_tensor -> video but the
    # body still returned the old name `video_tensor`, raising NameError.
    return video
|
| 124 |
|
| 125 |
+
def process_video_for_control(video, control_type):
    """Process video based on the selected control type"""
    # Route the frames to the matching per-control processor; anything
    # unrecognized is passed through untouched.
    if control_type == "canny":
        processed = process_video_for_canny(video)
    elif control_type == "depth":
        processed = process_video_for_depth(video)
    elif control_type == "pose":
        processed = process_video_for_pose(video)
    else:
        processed = video
    return processed
|
| 135 |
|
| 136 |
@spaces.GPU(duration=120)
|
| 137 |
def generate_video(
|
|
|
|
| 172 |
# Load the appropriate control LoRA and update state
|
| 173 |
updated_lora_state = load_control_lora(control_type, current_lora_state)
|
| 174 |
|
| 175 |
+
# Loads video into a list of pil images
|
| 176 |
+
video = load_video(reference_video)
|
| 177 |
+
progress(0.1, desc="Processing video for control...")
|
|
|
|
|
|
|
|
|
|
| 178 |
|
| 179 |
# Process video based on control type
|
| 180 |
processed_video = process_video_for_control(video, control_type)
|
| 181 |
+
processed_video = read_video(processed_video) # turns to tensor
|
| 182 |
|
| 183 |
progress(0.2, desc="Preparing generation parameters...")
|
| 184 |
|