Refactor CogVideoX app with modular functions and improved code structure
Browse files
app.py
CHANGED
@@ -5,7 +5,8 @@ import torch
|
|
5 |
import tempfile
|
6 |
import os
|
7 |
import spaces
|
8 |
-
|
|
|
9 |
TRANSFORMER_MODELS = [
|
10 |
"sayakpaul/pika-dissolve-v0",
|
11 |
"finetrainers/crush-smol-v0",
|
@@ -13,24 +14,42 @@ TRANSFORMER_MODELS = [
|
|
13 |
"finetrainers/cakeify-v0"
|
14 |
]
|
15 |
|
16 |
-
|
17 |
-
|
18 |
-
#
|
|
|
19 |
transformer = CogVideoXTransformer3DModel.from_pretrained(
|
20 |
transformer_model,
|
21 |
torch_dtype=torch.bfloat16
|
22 |
)
|
23 |
-
|
24 |
-
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
25 |
|
26 |
-
#
|
|
|
27 |
pipeline = DiffusionPipeline.from_pretrained(
|
28 |
"THUDM/CogVideoX-5b",
|
29 |
transformer=transformer,
|
30 |
torch_dtype=torch.bfloat16
|
31 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
32 |
|
33 |
-
#
|
|
|
34 |
video_frames = pipeline(
|
35 |
prompt=prompt,
|
36 |
negative_prompt=negative_prompt,
|
@@ -40,128 +59,151 @@ def generate_video(transformer_model, prompt, negative_prompt, num_frames, heigh
|
|
40 |
num_inference_steps=num_inference_steps
|
41 |
).frames[0]
|
42 |
|
43 |
-
|
44 |
-
|
45 |
-
export_to_video(video_frames, tmp_file.name, fps=25)
|
46 |
-
return tmp_file.name
|
47 |
|
48 |
-
|
49 |
-
|
50 |
-
|
|
|
51 |
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
)
|
80 |
-
height = gr.Slider(
|
81 |
-
minimum=32,
|
82 |
-
maximum=1024,
|
83 |
-
value=256,
|
84 |
-
step=64,
|
85 |
-
label="Height",
|
86 |
-
info="Video height in pixels"
|
87 |
)
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
step=64,
|
93 |
-
label="Width",
|
94 |
-
info="Video width in pixels"
|
95 |
)
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
value=
|
100 |
-
step=1,
|
101 |
-
label="Inference Steps",
|
102 |
-
info="Higher number = better quality but slower"
|
103 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
104 |
|
105 |
-
|
|
|
|
|
106 |
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
119 |
],
|
120 |
-
[
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
|
|
|
|
|
|
125 |
],
|
126 |
-
|
127 |
-
|
128 |
-
"3D_dissolve In a 3D appearance, a bookshelf filled with books is surrounded by a burst of red sparks, creating a dramatic and explosive effect against a black background.",
|
129 |
-
"inconsistent motion, blurry motion, worse quality, degenerate outputs, deformed outputs",
|
130 |
-
8, 32, 32, 10
|
131 |
-
],
|
132 |
-
[
|
133 |
-
"finetrainers/cakeify-v0",
|
134 |
-
"PIKA_CAKEIFY On a gleaming glass display stand, a sleek black purse quietly commands attention. Suddenly, a knife appears and slices through the shoe, revealing a fluffy vanilla sponge at its core. Immediately, it turns into a hyper-realistic prop cake, delighting the senses with its playful juxtaposition of the everyday and the extraordinary.",
|
135 |
-
"inconsistent motion, blurry motion, worse quality, degenerate outputs, deformed outputs",
|
136 |
-
8, 32, 32, 10
|
137 |
-
]
|
138 |
-
],
|
139 |
-
inputs=[
|
140 |
-
model_dropdown,
|
141 |
-
prompt_input,
|
142 |
-
negative_prompt_input,
|
143 |
-
num_frames,
|
144 |
-
height,
|
145 |
-
width,
|
146 |
-
num_inference_steps
|
147 |
-
],
|
148 |
-
label="Prompt Examples"
|
149 |
-
)
|
150 |
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
-
|
163 |
-
|
164 |
-
|
|
|
|
|
165 |
|
166 |
-
#
|
167 |
-
|
|
|
|
|
|
5 |
import tempfile
|
6 |
import os
|
7 |
import spaces
|
8 |
+
|
9 |
+
# Available transformer models
|
10 |
TRANSFORMER_MODELS = [
|
11 |
"sayakpaul/pika-dissolve-v0",
|
12 |
"finetrainers/crush-smol-v0",
|
|
|
14 |
"finetrainers/cakeify-v0"
|
15 |
]
|
16 |
|
17 |
+
def load_models(transformer_model):
|
18 |
+
"""Load transformer and pipeline models"""
|
19 |
+
# Load the selected transformer model
|
20 |
+
print(f"Loading model: {transformer_model}")
|
21 |
transformer = CogVideoXTransformer3DModel.from_pretrained(
|
22 |
transformer_model,
|
23 |
torch_dtype=torch.bfloat16
|
24 |
)
|
|
|
|
|
25 |
|
26 |
+
# Initialize the pipeline
|
27 |
+
print("Initializing pipeline")
|
28 |
pipeline = DiffusionPipeline.from_pretrained(
|
29 |
"THUDM/CogVideoX-5b",
|
30 |
transformer=transformer,
|
31 |
torch_dtype=torch.bfloat16
|
32 |
+
)
|
33 |
+
|
34 |
+
return pipeline
|
35 |
+
|
36 |
+
def save_video(video_frames, fps=25):
|
37 |
+
"""Save video frames to a temporary file"""
|
38 |
+
print("Saving video")
|
39 |
+
with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as tmp_file:
|
40 |
+
export_to_video(video_frames, tmp_file.name, fps=fps)
|
41 |
+
return tmp_file.name
|
42 |
+
|
43 |
+
@spaces.GPU
|
44 |
+
def generate_video_pipeline(pipeline, prompt, negative_prompt, num_frames, height, width, num_inference_steps):
|
45 |
+
"""Generate video using the pipeline"""
|
46 |
+
# Move to appropriate device
|
47 |
+
print("Moving to device")
|
48 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
49 |
+
pipeline = pipeline.to(device)
|
50 |
|
51 |
+
# Generate video
|
52 |
+
print("Generating video")
|
53 |
video_frames = pipeline(
|
54 |
prompt=prompt,
|
55 |
negative_prompt=negative_prompt,
|
|
|
59 |
num_inference_steps=num_inference_steps
|
60 |
).frames[0]
|
61 |
|
62 |
+
print("Video generated")
|
63 |
+
return video_frames
|
|
|
|
|
64 |
|
65 |
+
def generate_video(transformer_model, prompt, negative_prompt, num_frames, height, width, num_inference_steps):
|
66 |
+
"""Main function to handle the video generation process"""
|
67 |
+
# Load models
|
68 |
+
pipeline = load_models(transformer_model)
|
69 |
|
70 |
+
# Generate video frames
|
71 |
+
video_frames = generate_video_pipeline(
|
72 |
+
pipeline,
|
73 |
+
prompt,
|
74 |
+
negative_prompt,
|
75 |
+
num_frames,
|
76 |
+
height,
|
77 |
+
width,
|
78 |
+
num_inference_steps
|
79 |
+
)
|
80 |
+
|
81 |
+
# Save and return video path
|
82 |
+
print("Saving video")
|
83 |
+
return save_video(video_frames)
|
84 |
+
|
85 |
+
def create_interface():
|
86 |
+
"""Create and configure the Gradio interface"""
|
87 |
+
with gr.Blocks() as demo:
|
88 |
+
gr.Markdown("# CogVideoX Video Generator")
|
89 |
+
|
90 |
+
with gr.Row():
|
91 |
+
with gr.Column():
|
92 |
+
# Inputs
|
93 |
+
model_dropdown = gr.Dropdown(
|
94 |
+
choices=TRANSFORMER_MODELS,
|
95 |
+
value=TRANSFORMER_MODELS[0],
|
96 |
+
label="Transformer Model"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
97 |
)
|
98 |
+
prompt_input = gr.Textbox(
|
99 |
+
lines=5,
|
100 |
+
label="Prompt",
|
101 |
+
placeholder="Describe the video you want to generate..."
|
|
|
|
|
|
|
102 |
)
|
103 |
+
negative_prompt_input = gr.Textbox(
|
104 |
+
lines=2,
|
105 |
+
label="Negative Prompt",
|
106 |
+
value="inconsistent motion, blurry motion, worse quality, degenerate outputs, deformed outputs"
|
|
|
|
|
|
|
107 |
)
|
108 |
+
|
109 |
+
with gr.Accordion("Advanced Parameters", open=False):
|
110 |
+
num_frames = gr.Slider(
|
111 |
+
minimum=8,
|
112 |
+
maximum=128,
|
113 |
+
value=50,
|
114 |
+
step=1,
|
115 |
+
label="Number of Frames",
|
116 |
+
info="Number of frames in the video"
|
117 |
+
)
|
118 |
+
height = gr.Slider(
|
119 |
+
minimum=32,
|
120 |
+
maximum=1024,
|
121 |
+
value=224,
|
122 |
+
step=64,
|
123 |
+
label="Height",
|
124 |
+
info="Video height in pixels"
|
125 |
+
)
|
126 |
+
width = gr.Slider(
|
127 |
+
minimum=32,
|
128 |
+
maximum=1024,
|
129 |
+
value=224,
|
130 |
+
step=64,
|
131 |
+
label="Width",
|
132 |
+
info="Video width in pixels"
|
133 |
+
)
|
134 |
+
num_inference_steps = gr.Slider(
|
135 |
+
minimum=10,
|
136 |
+
maximum=100,
|
137 |
+
value=50,
|
138 |
+
step=1,
|
139 |
+
label="Inference Steps",
|
140 |
+
info="Higher number = better quality but slower"
|
141 |
+
)
|
142 |
+
|
143 |
+
generate_btn = gr.Button("Generate Video")
|
144 |
|
145 |
+
with gr.Column():
|
146 |
+
# Output
|
147 |
+
video_output = gr.Video(label="Generated Video")
|
148 |
|
149 |
+
# Add examples
|
150 |
+
gr.Examples(
|
151 |
+
examples=[
|
152 |
+
[
|
153 |
+
"sayakpaul/pika-dissolve-v0",
|
154 |
+
"PIKA_DISSOLVE A slender glass vase, brimming with tiny white pebbles, stands centered on a polished ebony dais. Without warning, the glass begins to dissolve from the edges inward. Wisps of translucent dust swirl upward in an elegant spiral, illuminating each pebble as they drop onto the dais. The gently drifting dust eventually settles, leaving only the scattered stones and faint traces of shimmering powder on the stage.",
|
155 |
+
"inconsistent motion, blurry motion, worse quality, degenerate outputs, deformed outputs",
|
156 |
+
8, 32, 32, 10
|
157 |
+
],
|
158 |
+
[
|
159 |
+
"finetrainers/crush-smol-v0",
|
160 |
+
"DIFF_crush A thick burger is placed on a dining table, and a large metal cylinder descends from above, crushing the burger as if it were under a hydraulic press. The bulb is crushed, leaving a pile of debris around it.",
|
161 |
+
"inconsistent motion, blurry motion, worse quality, degenerate outputs, deformed outputs",
|
162 |
+
8, 32, 32, 10
|
163 |
+
],
|
164 |
+
[
|
165 |
+
"finetrainers/3dgs-v0",
|
166 |
+
"3D_dissolve In a 3D appearance, a bookshelf filled with books is surrounded by a burst of red sparks, creating a dramatic and explosive effect against a black background.",
|
167 |
+
"inconsistent motion, blurry motion, worse quality, degenerate outputs, deformed outputs",
|
168 |
+
8, 32, 32, 10
|
169 |
+
],
|
170 |
+
[
|
171 |
+
"finetrainers/cakeify-v0",
|
172 |
+
"PIKA_CAKEIFY On a gleaming glass display stand, a sleek black purse quietly commands attention. Suddenly, a knife appears and slices through the shoe, revealing a fluffy vanilla sponge at its core. Immediately, it turns into a hyper-realistic prop cake, delighting the senses with its playful juxtaposition of the everyday and the extraordinary.",
|
173 |
+
"inconsistent motion, blurry motion, worse quality, degenerate outputs, deformed outputs",
|
174 |
+
8, 32, 32, 10
|
175 |
+
]
|
176 |
],
|
177 |
+
inputs=[
|
178 |
+
model_dropdown,
|
179 |
+
prompt_input,
|
180 |
+
negative_prompt_input,
|
181 |
+
num_frames,
|
182 |
+
height,
|
183 |
+
width,
|
184 |
+
num_inference_steps
|
185 |
],
|
186 |
+
label="Prompt Examples"
|
187 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
188 |
|
189 |
+
# Connect the function
|
190 |
+
generate_btn.click(
|
191 |
+
fn=generate_video,
|
192 |
+
inputs=[
|
193 |
+
model_dropdown,
|
194 |
+
prompt_input,
|
195 |
+
negative_prompt_input,
|
196 |
+
num_frames,
|
197 |
+
height,
|
198 |
+
width,
|
199 |
+
num_inference_steps
|
200 |
+
],
|
201 |
+
outputs=video_output
|
202 |
+
)
|
203 |
+
|
204 |
+
return demo
|
205 |
|
206 |
+
# Launch the application
|
207 |
+
if __name__ == "__main__":
|
208 |
+
demo = create_interface()
|
209 |
+
demo.launch()
|