Maximofn committed on
Commit
a203df4
·
1 Parent(s): 4cfb1f8

Refactor CogVideoX app with modular functions and improved code structure

Browse files
Files changed (1) hide show
  1. app.py +165 -123
app.py CHANGED
@@ -5,7 +5,8 @@ import torch
5
  import tempfile
6
  import os
7
  import spaces
8
- # Lista de modelos disponibles
 
9
  TRANSFORMER_MODELS = [
10
  "sayakpaul/pika-dissolve-v0",
11
  "finetrainers/crush-smol-v0",
@@ -13,24 +14,42 @@ TRANSFORMER_MODELS = [
13
  "finetrainers/cakeify-v0"
14
  ]
15
 
16
- @spaces.GPU
17
- def generate_video(transformer_model, prompt, negative_prompt, num_frames, height, width, num_inference_steps):
18
- # Cargar el modelo del transformer seleccionado
 
19
  transformer = CogVideoXTransformer3DModel.from_pretrained(
20
  transformer_model,
21
  torch_dtype=torch.bfloat16
22
  )
23
-
24
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
25
 
26
- # Inicializar el pipeline
 
27
  pipeline = DiffusionPipeline.from_pretrained(
28
  "THUDM/CogVideoX-5b",
29
  transformer=transformer,
30
  torch_dtype=torch.bfloat16
31
- ).to(device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
 
33
- # Generar el video
 
34
  video_frames = pipeline(
35
  prompt=prompt,
36
  negative_prompt=negative_prompt,
@@ -40,128 +59,151 @@ def generate_video(transformer_model, prompt, negative_prompt, num_frames, heigh
40
  num_inference_steps=num_inference_steps
41
  ).frames[0]
42
 
43
- # Guardar el video en un archivo temporal
44
- with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as tmp_file:
45
- export_to_video(video_frames, tmp_file.name, fps=25)
46
- return tmp_file.name
47
 
48
- # Crear la interfaz de Gradio
49
- with gr.Blocks() as demo:
50
- gr.Markdown("# CogVideoX Video Generator")
 
51
 
52
- with gr.Row():
53
- with gr.Column():
54
- # Entradas
55
- model_dropdown = gr.Dropdown(
56
- choices=TRANSFORMER_MODELS,
57
- value=TRANSFORMER_MODELS[0],
58
- label="Transformer Model"
59
- )
60
- prompt_input = gr.Textbox(
61
- lines=5,
62
- label="Prompt",
63
- placeholder="Describe the video you want to generate..."
64
- )
65
- negative_prompt_input = gr.Textbox(
66
- lines=2,
67
- label="Negative Prompt",
68
- value="inconsistent motion, blurry motion, worse quality, degenerate outputs, deformed outputs"
69
- )
70
-
71
- with gr.Accordion("Advanced Parameters", open=False):
72
- num_frames = gr.Slider(
73
- minimum=8,
74
- maximum=128,
75
- value=50,
76
- step=1,
77
- label="Number of Frames",
78
- info="Number of frames in the video"
79
- )
80
- height = gr.Slider(
81
- minimum=32,
82
- maximum=1024,
83
- value=256,
84
- step=64,
85
- label="Height",
86
- info="Video height in pixels"
87
  )
88
- width = gr.Slider(
89
- minimum=32,
90
- maximum=1024,
91
- value=256,
92
- step=64,
93
- label="Width",
94
- info="Video width in pixels"
95
  )
96
- num_inference_steps = gr.Slider(
97
- minimum=10,
98
- maximum=100,
99
- value=50,
100
- step=1,
101
- label="Inference Steps",
102
- info="Higher number = better quality but slower"
103
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
104
 
105
- generate_btn = gr.Button("Generate Video")
 
 
106
 
107
- with gr.Column():
108
- # Salida
109
- video_output = gr.Video(label="Generated Video")
110
-
111
- # Agregar ejemplos
112
- gr.Examples(
113
- examples=[
114
- [
115
- "sayakpaul/pika-dissolve-v0",
116
- "PIKA_DISSOLVE A slender glass vase, brimming with tiny white pebbles, stands centered on a polished ebony dais. Without warning, the glass begins to dissolve from the edges inward. Wisps of translucent dust swirl upward in an elegant spiral, illuminating each pebble as they drop onto the dais. The gently drifting dust eventually settles, leaving only the scattered stones and faint traces of shimmering powder on the stage.",
117
- "inconsistent motion, blurry motion, worse quality, degenerate outputs, deformed outputs",
118
- 8, 32, 32, 10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
119
  ],
120
- [
121
- "finetrainers/crush-smol-v0",
122
- "DIFF_crush A thick burger is placed on a dining table, and a large metal cylinder descends from above, crushing the burger as if it were under a hydraulic press. The bulb is crushed, leaving a pile of debris around it.",
123
- "inconsistent motion, blurry motion, worse quality, degenerate outputs, deformed outputs",
124
- 8, 32, 32, 10
 
 
 
125
  ],
126
- [
127
- "finetrainers/3dgs-v0",
128
- "3D_dissolve In a 3D appearance, a bookshelf filled with books is surrounded by a burst of red sparks, creating a dramatic and explosive effect against a black background.",
129
- "inconsistent motion, blurry motion, worse quality, degenerate outputs, deformed outputs",
130
- 8, 32, 32, 10
131
- ],
132
- [
133
- "finetrainers/cakeify-v0",
134
- "PIKA_CAKEIFY On a gleaming glass display stand, a sleek black purse quietly commands attention. Suddenly, a knife appears and slices through the shoe, revealing a fluffy vanilla sponge at its core. Immediately, it turns into a hyper-realistic prop cake, delighting the senses with its playful juxtaposition of the everyday and the extraordinary.",
135
- "inconsistent motion, blurry motion, worse quality, degenerate outputs, deformed outputs",
136
- 8, 32, 32, 10
137
- ]
138
- ],
139
- inputs=[
140
- model_dropdown,
141
- prompt_input,
142
- negative_prompt_input,
143
- num_frames,
144
- height,
145
- width,
146
- num_inference_steps
147
- ],
148
- label="Prompt Examples"
149
- )
150
 
151
- # Conectar la función
152
- generate_btn.click(
153
- fn=generate_video,
154
- inputs=[
155
- model_dropdown,
156
- prompt_input,
157
- negative_prompt_input,
158
- num_frames,
159
- height,
160
- width,
161
- num_inference_steps
162
- ],
163
- outputs=video_output
164
- )
 
 
165
 
166
- # Lanzar la aplicación
167
- demo.launch()
 
 
 
5
  import tempfile
6
  import os
7
  import spaces
8
+
9
+ # Available transformer models
10
  TRANSFORMER_MODELS = [
11
  "sayakpaul/pika-dissolve-v0",
12
  "finetrainers/crush-smol-v0",
 
14
  "finetrainers/cakeify-v0"
15
  ]
16
 
17
def load_models(transformer_model):
    """Build a CogVideoX pipeline around the selected transformer checkpoint.

    Args:
        transformer_model: Identifier handed to
            ``CogVideoXTransformer3DModel.from_pretrained``.

    Returns:
        The ``DiffusionPipeline`` created from ``"THUDM/CogVideoX-5b"`` with
        the loaded transformer plugged in. This function does not move the
        pipeline to any device.
    """
    print(f"Loading model: {transformer_model}")
    # The transformer and the base pipeline are both loaded in bfloat16.
    dtype = torch.bfloat16
    selected_transformer = CogVideoXTransformer3DModel.from_pretrained(
        transformer_model,
        torch_dtype=dtype,
    )

    print("Initializing pipeline")
    return DiffusionPipeline.from_pretrained(
        "THUDM/CogVideoX-5b",
        transformer=selected_transformer,
        torch_dtype=dtype,
    )
35
+
36
def save_video(video_frames, fps=25):
    """Write video frames to a new temporary ``.mp4`` file and return its path.

    Args:
        video_frames: Frames to encode; passed unchanged to
            ``export_to_video``.
        fps: Frame rate passed to ``export_to_video`` (default 25).

    Returns:
        Path of the temporary file. It is created with ``delete=False``, so
        it is not removed automatically and cleanup is left to the caller.
    """
    print("Saving video")
    # Reserve a unique path, then close the handle before the encoder writes
    # to it: reopening a NamedTemporaryFile by name while it is still open is
    # not portable (it fails on Windows) and keeps a needless descriptor open.
    with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as tmp_file:
        video_path = tmp_file.name
    export_to_video(video_frames, video_path, fps=fps)
    return video_path
42
+
43
+ @spaces.GPU
44
+ def generate_video_pipeline(pipeline, prompt, negative_prompt, num_frames, height, width, num_inference_steps):
45
+ """Generate video using the pipeline"""
46
+ # Move to appropriate device
47
+ print("Moving to device")
48
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
49
+ pipeline = pipeline.to(device)
50
 
51
+ # Generate video
52
+ print("Generating video")
53
  video_frames = pipeline(
54
  prompt=prompt,
55
  negative_prompt=negative_prompt,
 
59
  num_inference_steps=num_inference_steps
60
  ).frames[0]
61
 
62
+ print("Video generated")
63
+ return video_frames
 
 
64
 
65
def generate_video(transformer_model, prompt, negative_prompt, num_frames, height, width, num_inference_steps):
    """Run the full generation flow: load models, generate frames, save video.

    Args:
        transformer_model: Transformer checkpoint identifier, passed to
            ``load_models``.
        prompt: Text prompt forwarded to ``generate_video_pipeline``.
        negative_prompt: Negative prompt forwarded to
            ``generate_video_pipeline``.
        num_frames: Forwarded to ``generate_video_pipeline``.
        height: Forwarded to ``generate_video_pipeline``.
        width: Forwarded to ``generate_video_pipeline``.
        num_inference_steps: Forwarded to ``generate_video_pipeline``.

    Returns:
        Whatever ``save_video`` returns for the generated frames (the path of
        the written video file).
    """
    # The pipeline is rebuilt on every call for the requested transformer.
    pipeline = load_models(transformer_model)

    video_frames = generate_video_pipeline(
        pipeline,
        prompt,
        negative_prompt,
        num_frames,
        height,
        width,
        num_inference_steps,
    )

    # save_video logs its own "Saving video" message, so it is not repeated
    # here (it used to be printed twice per request).
    return save_video(video_frames)
84
+
85
+ def create_interface():
86
+ """Create and configure the Gradio interface"""
87
+ with gr.Blocks() as demo:
88
+ gr.Markdown("# CogVideoX Video Generator")
89
+
90
+ with gr.Row():
91
+ with gr.Column():
92
+ # Inputs
93
+ model_dropdown = gr.Dropdown(
94
+ choices=TRANSFORMER_MODELS,
95
+ value=TRANSFORMER_MODELS[0],
96
+ label="Transformer Model"
 
 
 
 
 
 
 
 
97
  )
98
+ prompt_input = gr.Textbox(
99
+ lines=5,
100
+ label="Prompt",
101
+ placeholder="Describe the video you want to generate..."
 
 
 
102
  )
103
+ negative_prompt_input = gr.Textbox(
104
+ lines=2,
105
+ label="Negative Prompt",
106
+ value="inconsistent motion, blurry motion, worse quality, degenerate outputs, deformed outputs"
 
 
 
107
  )
108
+
109
+ with gr.Accordion("Advanced Parameters", open=False):
110
+ num_frames = gr.Slider(
111
+ minimum=8,
112
+ maximum=128,
113
+ value=50,
114
+ step=1,
115
+ label="Number of Frames",
116
+ info="Number of frames in the video"
117
+ )
118
+ height = gr.Slider(
119
+ minimum=32,
120
+ maximum=1024,
121
+ value=224,
122
+ step=64,
123
+ label="Height",
124
+ info="Video height in pixels"
125
+ )
126
+ width = gr.Slider(
127
+ minimum=32,
128
+ maximum=1024,
129
+ value=224,
130
+ step=64,
131
+ label="Width",
132
+ info="Video width in pixels"
133
+ )
134
+ num_inference_steps = gr.Slider(
135
+ minimum=10,
136
+ maximum=100,
137
+ value=50,
138
+ step=1,
139
+ label="Inference Steps",
140
+ info="Higher number = better quality but slower"
141
+ )
142
+
143
+ generate_btn = gr.Button("Generate Video")
144
 
145
+ with gr.Column():
146
+ # Output
147
+ video_output = gr.Video(label="Generated Video")
148
 
149
+ # Add examples
150
+ gr.Examples(
151
+ examples=[
152
+ [
153
+ "sayakpaul/pika-dissolve-v0",
154
+ "PIKA_DISSOLVE A slender glass vase, brimming with tiny white pebbles, stands centered on a polished ebony dais. Without warning, the glass begins to dissolve from the edges inward. Wisps of translucent dust swirl upward in an elegant spiral, illuminating each pebble as they drop onto the dais. The gently drifting dust eventually settles, leaving only the scattered stones and faint traces of shimmering powder on the stage.",
155
+ "inconsistent motion, blurry motion, worse quality, degenerate outputs, deformed outputs",
156
+ 8, 32, 32, 10
157
+ ],
158
+ [
159
+ "finetrainers/crush-smol-v0",
160
+ "DIFF_crush A thick burger is placed on a dining table, and a large metal cylinder descends from above, crushing the burger as if it were under a hydraulic press. The bulb is crushed, leaving a pile of debris around it.",
161
+ "inconsistent motion, blurry motion, worse quality, degenerate outputs, deformed outputs",
162
+ 8, 32, 32, 10
163
+ ],
164
+ [
165
+ "finetrainers/3dgs-v0",
166
+ "3D_dissolve In a 3D appearance, a bookshelf filled with books is surrounded by a burst of red sparks, creating a dramatic and explosive effect against a black background.",
167
+ "inconsistent motion, blurry motion, worse quality, degenerate outputs, deformed outputs",
168
+ 8, 32, 32, 10
169
+ ],
170
+ [
171
+ "finetrainers/cakeify-v0",
172
+ "PIKA_CAKEIFY On a gleaming glass display stand, a sleek black purse quietly commands attention. Suddenly, a knife appears and slices through the shoe, revealing a fluffy vanilla sponge at its core. Immediately, it turns into a hyper-realistic prop cake, delighting the senses with its playful juxtaposition of the everyday and the extraordinary.",
173
+ "inconsistent motion, blurry motion, worse quality, degenerate outputs, deformed outputs",
174
+ 8, 32, 32, 10
175
+ ]
176
  ],
177
+ inputs=[
178
+ model_dropdown,
179
+ prompt_input,
180
+ negative_prompt_input,
181
+ num_frames,
182
+ height,
183
+ width,
184
+ num_inference_steps
185
  ],
186
+ label="Prompt Examples"
187
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
188
 
189
+ # Connect the function
190
+ generate_btn.click(
191
+ fn=generate_video,
192
+ inputs=[
193
+ model_dropdown,
194
+ prompt_input,
195
+ negative_prompt_input,
196
+ num_frames,
197
+ height,
198
+ width,
199
+ num_inference_steps
200
+ ],
201
+ outputs=video_output
202
+ )
203
+
204
+ return demo
205
 
206
# Script entry point: build the Gradio UI and start serving it.
if __name__ == "__main__":
    demo = create_interface()
    demo.launch()