import gradio as gr
from diffusers import CogVideoXTransformer3DModel, DiffusionPipeline
from diffusers.utils import export_to_video
import torch
import tempfile
import os
import spaces
# Available transformer checkpoints
TRANSFORMER_MODELS = [
    "sayakpaul/pika-dissolve-v0",
    "finetrainers/crush-smol-v0",
    "finetrainers/3dgs-v0",
    "finetrainers/cakeify-v0"
]
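# Note: the checkpoints above appear to be effect-specific CogVideoX transformer
# finetunes (e.g. "cakeify", "crush", "dissolve") trained with the finetrainers
# project; only the transformer is swapped at generation time, while the rest of
# the CogVideoX-5b pipeline below stays the same.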
@spaces.GPU  # request a GPU for this call on ZeroGPU Spaces (the `spaces` import is otherwise unused)
def generate_video(transformer_model, prompt, negative_prompt, num_frames, height, width, num_inference_steps):
    # Load the selected finetuned transformer
    transformer = CogVideoXTransformer3DModel.from_pretrained(
        transformer_model,
        torch_dtype=torch.bfloat16
    )

    # Build the CogVideoX pipeline around the selected transformer
    pipeline = DiffusionPipeline.from_pretrained(
        "THUDM/CogVideoX-5b",
        transformer=transformer,
        torch_dtype=torch.bfloat16
    ).to("cuda")

    # Generate the video frames
    video_frames = pipeline(
        prompt=prompt,
        negative_prompt=negative_prompt,
        num_frames=num_frames,
        height=height,
        width=width,
        num_inference_steps=num_inference_steps
    ).frames[0]

    # Save the video to a temporary file and return its path
    # (CogVideoX samples at 8 fps; fps=25 plays the clip back faster)
    with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as tmp_file:
        export_to_video(video_frames, tmp_file.name, fps=25)
        return tmp_file.name
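
# Optional note (not in the original code): on GPUs with limited VRAM the 5B
# pipeline in bf16 can run out of memory at 480x720. A common diffusers-level
# mitigation is to offload submodules and tile the VAE instead of calling
# .to("cuda") directly, e.g.:
#
#     pipeline.enable_model_cpu_offload()
#     pipeline.vae.enable_tiling()
#
# Both are standard diffusers APIs; whether they are needed depends on the
# Space's hardware, which this file does not specify.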
# Build the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# CogVideoX Video Generator")

    with gr.Row():
        with gr.Column():
            # Inputs
            model_dropdown = gr.Dropdown(
                choices=TRANSFORMER_MODELS,
                value=TRANSFORMER_MODELS[0],
                label="Transformer model"
            )
            prompt_input = gr.Textbox(
                lines=5,
                label="Prompt",
                placeholder="Describe the video you want to generate..."
            )
            negative_prompt_input = gr.Textbox(
                lines=2,
                label="Negative prompt",
                value="inconsistent motion, blurry motion, worse quality, degenerate outputs, deformed outputs"
            )
            with gr.Accordion("Advanced parameters", open=False):
                num_frames = gr.Slider(
                    minimum=8,
                    maximum=128,
                    value=49,  # CogVideoX-5b is trained to generate 49-frame clips
                    step=1,
                    label="Number of frames",
                    info="Number of frames in the video"
                )
                height = gr.Slider(
                    minimum=32,
                    maximum=1024,
                    value=480,  # CogVideoX-5b's native resolution is 480x720
                    step=16,
                    label="Height",
                    info="Video height in pixels"
                )
                width = gr.Slider(
                    minimum=32,
                    maximum=1024,
                    value=720,
                    step=16,
                    label="Width",
                    info="Video width in pixels"
                )
                num_inference_steps = gr.Slider(
                    minimum=10,
                    maximum=100,
                    value=50,  # the pipeline default; fewer steps are faster but lower quality
                    step=1,
                    label="Inference steps",
                    info="More steps = better quality but slower"
                )

            generate_btn = gr.Button("Generate video")
        with gr.Column():
            # Output
            video_output = gr.Video(label="Generated video")

    # Wire the button to the generation function
    generate_btn.click(
        fn=generate_video,
        inputs=[
            model_dropdown,
            prompt_input,
            negative_prompt_input,
            num_frames,
            height,
            width,
            num_inference_steps
        ],
        outputs=video_output
    )

# Launch the app
demo.launch()
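
# Deployment note (assumption, not taken from the original Space): building this
# app typically requires a requirements.txt along these lines; the exact set and
# versions depend on the Space configuration:
#   gradio
#   torch
#   diffusers
#   transformers
#   accelerate
#   sentencepiece
#   imageio-ffmpeg
#   spaces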