Spaces:
Runtime error
Runtime error
Update opensora/serve/gradio_web_server.py
Browse files
opensora/serve/gradio_web_server.py
CHANGED
@@ -24,24 +24,26 @@ from opensora.models.diffusion.latte.modeling_latte import LatteT2V
|
|
24 |
from opensora.sample.pipeline_videogen import VideoGenPipeline
|
25 |
from opensora.serve.gradio_utils import block_css, title_markdown, randomize_seed_fn, set_env, examples, DESCRIPTION
|
26 |
|
|
|
27 |
|
28 |
-
@
|
29 |
def generate_img(prompt, sample_steps, scale, seed=0, randomize_seed=False, force_images=False):
|
30 |
seed = int(randomize_seed_fn(seed, randomize_seed))
|
31 |
set_env(seed)
|
32 |
video_length = transformer_model.config.video_length if not force_images else 1
|
33 |
height, width = int(args.version.split('x')[1]), int(args.version.split('x')[2])
|
34 |
num_frames = 1 if video_length == 1 else int(args.version.split('x')[0])
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
|
|
45 |
|
46 |
torch.cuda.empty_cache()
|
47 |
videos = videos[0]
|
|
|
24 |
from opensora.sample.pipeline_videogen import VideoGenPipeline
|
25 |
from opensora.serve.gradio_utils import block_css, title_markdown, randomize_seed_fn, set_env, examples, DESCRIPTION
|
26 |
|
27 |
+
import space
|
28 |
|
29 |
+
@spaces.GPU
|
30 |
def generate_img(prompt, sample_steps, scale, seed=0, randomize_seed=False, force_images=False):
|
31 |
seed = int(randomize_seed_fn(seed, randomize_seed))
|
32 |
set_env(seed)
|
33 |
video_length = transformer_model.config.video_length if not force_images else 1
|
34 |
height, width = int(args.version.split('x')[1]), int(args.version.split('x')[2])
|
35 |
num_frames = 1 if video_length == 1 else int(args.version.split('x')[0])
|
36 |
+
with torch.no_grad():
|
37 |
+
videos = videogen_pipeline(prompt,
|
38 |
+
video_length=video_length,
|
39 |
+
height=height,
|
40 |
+
width=width,
|
41 |
+
num_inference_steps=sample_steps,
|
42 |
+
guidance_scale=scale,
|
43 |
+
enable_temporal_attentions=not force_images,
|
44 |
+
num_images_per_prompt=1,
|
45 |
+
mask_feature=True,
|
46 |
+
).video
|
47 |
|
48 |
torch.cuda.empty_cache()
|
49 |
videos = videos[0]
|