fantaxy committed on
Commit 915ecc0 · verified · 1 Parent(s): 87bc4ad

Update app.py

Files changed (1): app.py (+11 −18)
app.py CHANGED
@@ -2,7 +2,6 @@ import spaces
 import gradio as gr
 import os
 import sys
-import argparse
 import random
 import time
 from omegaconf import OmegaConf
@@ -31,23 +30,19 @@ def download_model():
     if not os.path.exists(local_file):
         hf_hub_download(repo_id=REPO_ID, filename=filename, local_dir='./checkpoints/dynamicrafter_1024_v1/', force_download=True)
 
-
-
 download_model()
 ckpt_path='checkpoints/dynamicrafter_1024_v1/model.ckpt'
 config_file='configs/inference_1024_v1.0.yaml'
 config = OmegaConf.load(config_file)
 model_config = config.pop("model", OmegaConf.create())
-model_config['params']['unet_config']['params']['use_checkpoint']=False
+model_config['params']['unet_config']['params']['use_checkpoint']=True  # Use checkpointing to optimize memory usage
 model = instantiate_from_config(model_config)
 assert os.path.exists(ckpt_path), "Error: checkpoint Not Found!"
 model = load_model_checkpoint(model, ckpt_path)
 model.eval()
 model = model.cuda()
 
-
-
-@spaces.GPU(duration=300)
+@spaces.GPU(duration=300, gpu_type="h100")  # Specify an H100 GPU
 def infer(image, prompt, steps=50, cfg_scale=7.5, eta=1.0, fs=3, seed=123):
     resolution = (576, 1024)
     save_fps = 8
@@ -56,31 +51,29 @@ def infer(image, prompt, steps=50, cfg_scale=7.5, eta=1.0, fs=3, seed=123):
         transforms.Resize(min(resolution)),
         transforms.CenterCrop(resolution),
         ])
-    torch.cuda.empty_cache()
-    print('start:', prompt, time.strftime('%Y-%m-%d %H:%M:%S',time.localtime(time.time())))
+    torch.cuda.empty_cache()  # Clear cached GPU memory
+    print('start:', prompt, time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time())))
     start = time.time()
     if steps > 60:
         steps = 60
 
-    batch_size=1
+    batch_size = 1
     channels = model.model.diffusion_model.out_channels
    frames = model.temporal_length
     h, w = resolution[0] // 8, resolution[1] // 8
     noise_shape = [batch_size, channels, frames, h, w]
 
-    # text cond
-    with torch.no_grad(), torch.cuda.amp.autocast():
+    # Build the text conditioning
+    with torch.no_grad(), torch.cuda.amp.autocast():  # Reduce memory usage and speed up computation
         text_emb = model.get_learned_conditioning([prompt])
-
-        # img cond
+
+        # Build the image conditioning
         img_tensor = torch.from_numpy(image).permute(2, 0, 1).float().to(model.device)
         img_tensor = (img_tensor / 255. - 0.5) * 2
-
         image_tensor_resized = transform(img_tensor) #3,256,256
         videos = image_tensor_resized.unsqueeze(0) # bchw
 
         z = get_latent_z(model, videos.unsqueeze(2)) #bc,1,hw
-
         img_tensor_repeat = repeat(z, 'b c t h w -> b c (repeat t) h w', repeat=frames)
 
         cond_images = model.embedder(img_tensor.unsqueeze(0)) ## blc
@@ -91,9 +84,9 @@ def infer(image, prompt, steps=50, cfg_scale=7.5, eta=1.0, fs=3, seed=123):
         fs = torch.tensor([fs], dtype=torch.long, device=model.device)
         cond = {"c_crossattn": [imtext_cond], "fs": fs, "c_concat": [img_tensor_repeat]}
 
-        ## inference
+        # Inference
         batch_samples = batch_ddim_sampling(model, cond, noise_shape, n_samples=1, ddim_steps=steps, ddim_eta=eta, cfg_scale=cfg_scale)
-        ## b,samples,c,t,h,w
+        # b,samples,c,t,h,w
 
         video_path = './output.mp4'
         save_videos(batch_samples, './', filenames=['output'], fps=save_fps)
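A note on the use_checkpoint change above: in UNet implementations of this family, the flag typically routes to activation (gradient) checkpointing, which drops intermediate activations during the forward pass and recomputes them on backward, trading compute for memory. Below is a minimal sketch of that mechanism, assuming torch.utils.checkpoint as the backend; the Block module is a hypothetical stand-in, not DynamiCrafter's UNet.

import torch
from torch.utils.checkpoint import checkpoint

class Block(torch.nn.Module):
    """Hypothetical stand-in for one UNet stage."""
    def __init__(self, dim=64, use_checkpoint=False):
        super().__init__()
        self.use_checkpoint = use_checkpoint
        self.net = torch.nn.Sequential(
            torch.nn.Linear(dim, dim), torch.nn.GELU(), torch.nn.Linear(dim, dim))

    def forward(self, x):
        if self.use_checkpoint and x.requires_grad:
            # Intermediate activations inside self.net are not stored;
            # they are recomputed if backward is ever called.
            return checkpoint(self.net, x, use_reentrant=False)
        return self.net(x)

x = torch.randn(2, 64, requires_grad=True)
y = Block(use_checkpoint=True)(x)
y.sum().backward()  # triggers the recompute

One caveat: under torch.no_grad(), as in this app's infer(), activations are not retained for backward anyway, so checkpointing saves little at pure inference time.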
 
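The edits inside infer() follow a common low-memory CUDA inference pattern: clear the allocator cache up front, then run the pipeline under torch.no_grad() (no autograd bookkeeping) combined with torch.cuda.amp.autocast() (eligible ops execute in float16). Here is a self-contained sketch of the same pattern, with a plain Linear layer standing in for the model:

import torch

model = torch.nn.Linear(1024, 1024).cuda().eval()  # placeholder model
x = torch.randn(1, 1024, device="cuda")

torch.cuda.empty_cache()  # release unoccupied cached GPU memory back to the device

with torch.no_grad(), torch.cuda.amp.autocast():
    # No gradient graph is built here, and matmul-heavy ops run in fp16.
    y = model(x)

print(y.dtype)  # torch.float16 under autocast

torch.cuda.amp.autocast() is the spelling the app uses; recent PyTorch releases prefer the equivalent torch.amp.autocast("cuda").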