seawolf2357 committed
Commit eb70557 · verified · 1 Parent(s): feda343

Update app.py

Files changed (1):
  app.py (+9, -12)
app.py CHANGED
@@ -108,11 +108,9 @@ def generate_image(prompt: str):
 # @spaces.GPU(duration=300, gpu_type="l40s")
 def infer(prompt, steps=50, cfg_scale=7.5, eta=1.0, fs=3, seed=123, frames=64):
     try:
-        # Generate the image
         image_path = generate_image(prompt)
         image = torchvision.io.read_image(image_path).float() / 255.0
 
-        # Check for Korean input and translate it
         if any('\u3131' <= char <= '\u318E' or '\uAC00' <= char <= '\uD7A3' for char in prompt):
             translated = translator(prompt, max_length=512)
             prompt = translated[0]['translation_text']
@@ -120,9 +118,7 @@ def infer(prompt, steps=50, cfg_scale=7.5, eta=1.0, fs=3, seed=123, frames=64):
         resolution = (576, 1024)
         save_fps = 8
         seed_everything(seed)
-        transform = transforms.Compose([
-            transforms.Resize(resolution, antialias=True),
-        ])
+        transform = transforms.Compose([transforms.Resize(resolution, antialias=True)])
 
         print('start:', prompt, time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time())))
         start = time.time()
@@ -130,30 +126,30 @@ def infer(prompt, steps=50, cfg_scale=7.5, eta=1.0, fs=3, seed=123, frames=64):
         steps = 60
 
         batch_size = 1
-        channels = model.module.model.diffusion_model.out_channels
+        channels = model.diffusion_model.out_channels  # model.module removed
         h, w = resolution[0] // 8, resolution[1] // 8
         noise_shape = [batch_size, channels, frames, h, w]
 
         with torch.no_grad(), torch.cuda.amp.autocast():
-            text_emb = model.module.get_learned_conditioning([prompt])
+            text_emb = model.get_learned_conditioning([prompt])  # model.module removed
 
             img_tensor = image.to(torch.cuda.current_device())
             img_tensor = (img_tensor - 0.5) * 2
             image_tensor_resized = transform(img_tensor)
             videos = image_tensor_resized.unsqueeze(0)
 
-            z = get_latent_z(model.module, videos.unsqueeze(2))
+            z = get_latent_z(model, videos.unsqueeze(2))  # model.module removed
             img_tensor_repeat = repeat(z, 'b c t h w -> b c (repeat t) h w', repeat=frames)
 
-            cond_images = model.module.embedder(img_tensor.unsqueeze(0))
-            img_emb = model.module.image_proj_model(cond_images)
+            cond_images = model.embedder(img_tensor.unsqueeze(0))  # model.module removed
+            img_emb = model.image_proj_model(cond_images)  # model.module removed
 
             imtext_cond = torch.cat([text_emb, img_emb], dim=1)
 
             fs = torch.tensor([fs], dtype=torch.long, device=torch.cuda.current_device())
             cond = {"c_crossattn": [imtext_cond], "fs": fs, "c_concat": [img_tensor_repeat]}
 
-            batch_samples = batch_ddim_sampling(model.module, cond, noise_shape, n_samples=1, ddim_steps=steps, ddim_eta=eta, cfg_scale=cfg_scale)
+            batch_samples = batch_ddim_sampling(model, cond, noise_shape, n_samples=1, ddim_steps=steps, ddim_eta=eta, cfg_scale=cfg_scale)  # model.module removed
 
         video_path = './output.mp4'
         save_videos(batch_samples, './', filenames=['output'], fps=save_fps)
@@ -168,7 +164,8 @@ def infer(prompt, steps=50, cfg_scale=7.5, eta=1.0, fs=3, seed=123, frames=64):
         print(f"Error occurred: {e}")
         return None
     finally:
-        torch.cuda.empty_cache()
+        torch.cuda.empty_cache()
+
 
 i2v_examples = [
     ['우주인 복장으로 기타를 치는 남자', 30, 7.5, 1.0, 6, 123, 64],
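The substance of this commit is the removal of the model.module indirection throughout infer(). The .module attribute exists only when a network is wrapped in torch.nn.DataParallel or DistributedDataParallel, so the edit suggests the Space now loads the model unwrapped. A minimal sketch of a defensive accessor that tolerates both cases (the unwrap name is hypothetical, not part of this repo):

    import torch.nn as nn

    def unwrap(model: nn.Module) -> nn.Module:
        # DataParallel / DistributedDataParallel keep the real network in
        # .module; a bare nn.Module does not, so fall back to the model itself.
        return getattr(model, "module", model)

Routing every attribute access through a helper like this keeps the code indifferent to how the checkpoint was loaded.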
 
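For reference, the Hangul check that decides whether a prompt is routed through the translator tests two Unicode blocks: U+3131..U+318E (Hangul Compatibility Jamo) and U+AC00..U+D7A3 (Hangul Syllables). A self-contained sketch of the same predicate (the contains_korean name is illustrative):

    def contains_korean(text: str) -> bool:
        # U+3131..U+318E: Hangul Compatibility Jamo (bare consonants and vowels)
        # U+AC00..U+D7A3: precomposed Hangul syllables
        return any('\u3131' <= ch <= '\u318E' or '\uAC00' <= ch <= '\uD7A3'
                   for ch in text)

    print(contains_korean('우주인 복장으로 기타를 치는 남자'))  # True, gets translated
    print(contains_korean('a man in a space suit playing guitar'))  # False, used as-is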
 
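The sampling geometry is unchanged by this commit: with resolution (576, 1024) and the usual 8x latent downsampling, the latent grid is 576 // 8 = 72 by 1024 // 8 = 128, and the single image latent is tiled across the time axis with einops before DDIM sampling. A toy reproduction of that tiling step; the channel count of 4 is an assumption for illustration:

    import torch
    from einops import repeat

    frames = 64
    # Single-frame latent (b, c, t, h, w); c=4 is an assumed latent depth.
    z = torch.randn(1, 4, 1, 576 // 8, 1024 // 8)
    # Tile the one frame across the whole clip, as infer() does.
    img_tensor_repeat = repeat(z, 'b c t h w -> b c (repeat t) h w', repeat=frames)
    print(img_tensor_repeat.shape)  # torch.Size([1, 4, 64, 72, 128])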