File size: 9,445 Bytes
ef058d8 a203df4 ef058d8 512f6a4 a203df4 ef058d8 a203df4 ef058d8 a203df4 7cd9249 a203df4 4fdc33b a203df4 ef058d8 a203df4 ef058d8 39c490f ef058d8 a203df4 ef058d8 a203df4 512f6a4 a203df4 ef058d8 a203df4 39c490f a203df4 39c490f a203df4 39c490f a203df4 d1cf1a0 a203df4 d1cf1a0 a203df4 d1cf1a0 a203df4 d1cf1a0 a203df4 39c490f a203df4 ef058d8 a203df4 490f3d3 a203df4 490f3d3 a203df4 490f3d3 a203df4 490f3d3 a203df4 4cfb1f8 a203df4 490f3d3 4cfb1f8 a203df4 4cfb1f8 a203df4 ef058d8 a203df4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 |
import gradio as gr
from diffusers import CogVideoXTransformer3DModel, DiffusionPipeline
from diffusers.utils import export_to_video
import torch
import tempfile
import os
import spaces
# Fine-tuned CogVideoX transformer checkpoints selectable in the UI.
TRANSFORMER_MODELS = [
    "sayakpaul/pika-dissolve-v0",
    "finetrainers/crush-smol-v0",
    "finetrainers/3dgs-v0",
    "finetrainers/cakeify-v0"
]

# Trigger-word prefix each fine-tune expects at the start of its prompt.
MODEL_PREFIXES = {
    "sayakpaul/pika-dissolve-v0": "PIKA_DISSOLVE",
    "finetrainers/crush-smol-v0": "DIFF_crush",
    "finetrainers/3dgs-v0": "3D_dissolve",
    "finetrainers/cakeify-v0": "PIKA_CAKEIFY"
}


def check_and_fix_prompt(transformer_model, prompt):
    """Return the prompt guaranteed to start with the model's trigger prefix.

    Unknown models get their prompt back untouched (not even stripped);
    known models get a stripped prompt with the prefix prepended if absent.
    """
    prefix = MODEL_PREFIXES.get(transformer_model)
    if prefix is None:
        print(f"No required prefix found for model: {transformer_model}")
        return prompt
    cleaned = prompt.strip()
    if cleaned.startswith(prefix):
        return cleaned
    print(f"Adding required prefix '{prefix}' to prompt")
    return f"{prefix} {cleaned}"
def load_models(transformer_model):
    """Build a CogVideoX pipeline with the selected fine-tuned transformer.

    The fine-tune replaces the base model's denoising transformer while the
    rest of the THUDM/CogVideoX-5b pipeline (VAE, text encoder, scheduler)
    is loaded unchanged.
    """
    print(f"Loading model: {transformer_model}")
    custom_transformer = CogVideoXTransformer3DModel.from_pretrained(
        transformer_model, torch_dtype=torch.bfloat16
    )
    print("Initializing pipeline")
    # bfloat16 throughout to keep the 5B base model within GPU memory.
    return DiffusionPipeline.from_pretrained(
        "THUDM/CogVideoX-5b",
        transformer=custom_transformer,
        torch_dtype=torch.bfloat16,
    )
def save_video(video_frames, fps=25):
    """Write frames to a temporary .mp4 file and return its path.

    delete=False is deliberate: the file must outlive this function so the
    caller (Gradio) can serve it to the browser.
    """
    print("Saving video")
    with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as handle:
        output_path = handle.name
        export_to_video(video_frames, output_path, fps=fps)
    return output_path
@spaces.GPU(duration=600)
def generate_video_pipeline(pipeline, prompt, negative_prompt, num_frames, height, width, num_inference_steps):
    """Run the diffusion pipeline on the best available device.

    Args:
        pipeline: a loaded CogVideoX DiffusionPipeline (see load_models).
        prompt / negative_prompt: text conditioning strings.
        num_frames, height, width, num_inference_steps: sampling parameters
            forwarded verbatim to the pipeline call.

    Returns:
        The first (and only) generated clip's frame list from the pipeline
        output (`.frames[0]`).
    """
    cuda_available = torch.cuda.is_available()
    print(f"Is CUDA available: {cuda_available}")
    # BUGFIX: the original queried get_device_name() unconditionally, which
    # raises RuntimeError on CPU-only hosts even though the CPU fallback
    # below is explicitly supported. Only query when CUDA exists.
    if cuda_available:
        print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
    # Move to appropriate device
    print("Moving to device")
    device = torch.device("cuda" if cuda_available else "cpu")
    pipeline = pipeline.to(device)
    # Generate video
    print("Generating video")
    video_frames = pipeline(
        prompt=prompt,
        negative_prompt=negative_prompt,
        num_frames=num_frames,
        height=height,
        width=width,
        num_inference_steps=num_inference_steps
    ).frames[0]
    print("Video generated")
    return video_frames
def generate_video(transformer_model, prompt, negative_prompt, num_frames, height, width, num_inference_steps):
    """End-to-end handler wired to the Generate button.

    Normalizes the prompt, loads the selected model, renders the frames,
    and returns the path of the saved .mp4 for the Video component.
    """
    print(f"Original prompt: {prompt}")
    final_prompt = check_and_fix_prompt(transformer_model, prompt)
    print(f"Final prompt: {final_prompt}")

    pipeline = load_models(transformer_model)

    frames = generate_video_pipeline(
        pipeline,
        final_prompt,
        negative_prompt,
        num_frames,
        height,
        width,
        num_inference_steps,
    )

    print("Saving video")
    return save_video(frames)
def create_interface():
    """Create and configure the Gradio interface.

    Layout: a two-column Blocks app — inputs (model picker, prompts,
    advanced sampling sliders, generate button) on the left, the rendered
    video on the right — plus one canned example per supported model.
    Returns the (unlaunched) Blocks instance.
    """
    with gr.Blocks() as demo:
        gr.Markdown("# CogVideoX Video Generator")
        with gr.Row():
            with gr.Column():
                # Inputs
                model_dropdown = gr.Dropdown(
                    choices=TRANSFORMER_MODELS,
                    value=TRANSFORMER_MODELS[0],
                    label="Transformer Model"
                )
                prompt_input = gr.Textbox(
                    lines=5,
                    label="Prompt",
                    placeholder="Describe the video you want to generate..."
                )
                negative_prompt_input = gr.Textbox(
                    lines=2,
                    label="Negative Prompt",
                    value="inconsistent motion, blurry motion, worse quality, degenerate outputs, deformed outputs"
                )
                # Sampling parameters, hidden by default behind an accordion.
                with gr.Accordion("Advanced Parameters", open=False):
                    num_frames = gr.Slider(
                        minimum=8,
                        maximum=128,
                        value=50,
                        step=1,
                        label="Number of Frames",
                        info="Number of frames in the video"
                    )
                    # step=64 keeps dimensions model-friendly multiples.
                    height = gr.Slider(
                        minimum=32,
                        maximum=1024,
                        value=512,
                        step=64,
                        label="Height",
                        info="Video height in pixels"
                    )
                    width = gr.Slider(
                        minimum=32,
                        maximum=1024,
                        value=512,
                        step=64,
                        label="Width",
                        info="Video width in pixels"
                    )
                    num_inference_steps = gr.Slider(
                        minimum=10,
                        maximum=100,
                        value=50,
                        step=1,
                        label="Inference Steps",
                        info="Higher number = better quality but slower"
                    )
                generate_btn = gr.Button("Generate Video")
            with gr.Column():
                # Output
                video_output = gr.Video(label="Generated Video")
        # Add examples — one row per supported model, each prompt already
        # carrying its required trigger prefix. The last value in each row
        # is a pre-rendered clip shown in the video component.
        # NOTE(review): video_output appears in `inputs` so that selecting an
        # example also populates the output slot from example_outputs/ —
        # confirm this matches the installed Gradio version's Examples API.
        gr.Examples(
            examples=[
                [
                    "sayakpaul/pika-dissolve-v0",
                    "PIKA_DISSOLVE A slender glass vase, brimming with tiny white pebbles, stands centered on a polished ebony dais. Without warning, the glass begins to dissolve from the edges inward. Wisps of translucent dust swirl upward in an elegant spiral, illuminating each pebble as they drop onto the dais. The gently drifting dust eventually settles, leaving only the scattered stones and faint traces of shimmering powder on the stage.",
                    "inconsistent motion, blurry motion, worse quality, degenerate outputs, deformed outputs",
                    50, 512, 512, 50,
                    "example_outputs/pika-dissolve-v0.mp4"
                ],
                [
                    "finetrainers/crush-smol-v0",
                    "DIFF_crush A thick burger is placed on a dining table, and a large metal cylinder descends from above, crushing the burger as if it were under a hydraulic press. The bulb is crushed, leaving a pile of debris around it.",
                    "inconsistent motion, blurry motion, worse quality, degenerate outputs, deformed outputs",
                    50, 512, 512, 50,
                    "example_outputs/crush-smol-v0.mp4"
                ],
                [
                    "finetrainers/3dgs-v0",
                    "3D_dissolve In a 3D appearance, a bookshelf filled with books is surrounded by a burst of red sparks, creating a dramatic and explosive effect against a black background.",
                    "inconsistent motion, blurry motion, worse quality, degenerate outputs, deformed outputs",
                    50, 512, 512, 50,
                    "example_outputs/3dgs-v0.mp4"
                ],
                [
                    "finetrainers/cakeify-v0",
                    "PIKA_CAKEIFY On a gleaming glass display stand, a sleek black purse quietly commands attention. Suddenly, a knife appears and slices through the shoe, revealing a fluffy vanilla sponge at its core. Immediately, it turns into a hyper-realistic prop cake, delighting the senses with its playful juxtaposition of the everyday and the extraordinary.",
                    "inconsistent motion, blurry motion, worse quality, degenerate outputs, deformed outputs",
                    50, 512, 512, 50,
                    "example_outputs/cakeify-v0.mp4"
                ]
            ],
            inputs=[
                model_dropdown,
                prompt_input,
                negative_prompt_input,
                num_frames,
                height,
                width,
                num_inference_steps,
                video_output,
            ],
            label="Prompt Examples"
        )
        # Connect the function
        generate_btn.click(
            fn=generate_video,
            inputs=[
                model_dropdown,
                prompt_input,
                negative_prompt_input,
                num_frames,
                height,
                width,
                num_inference_steps
            ],
            outputs=video_output
        )
    return demo
# Launch the application
# Script entry point: build the Blocks app and start the Gradio server.
if __name__ == "__main__":
    demo = create_interface()
    demo.launch()
|