seawolf2357 committed
Commit 85c0e7a · verified · 1 Parent(s): eb70557

Update app.py

Files changed (1)
  1. app.py +10 -13
app.py CHANGED
@@ -118,7 +118,9 @@ def infer(prompt, steps=50, cfg_scale=7.5, eta=1.0, fs=3, seed=123, frames=64):
     resolution = (576, 1024)
     save_fps = 8
     seed_everything(seed)
-    transform = transforms.Compose([transforms.Resize(resolution, antialias=True)])
+    transform = transforms.Compose([
+        transforms.Resize(resolution, antialias=True),
+    ])

     print('start:', prompt, time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time())))
     start = time.time()
@@ -126,38 +128,32 @@ def infer(prompt, steps=50, cfg_scale=7.5, eta=1.0, fs=3, seed=123, frames=64):
         steps = 60

         batch_size = 1
-        channels = model.diffusion_model.out_channels  # model.module removed
-        h, w = resolution[0] // 8, resolution[1] // 8
-        noise_shape = [batch_size, channels, frames, h, w]
+        channels = model.model.out_channels  # modified

         with torch.no_grad(), torch.cuda.amp.autocast():
-            text_emb = model.get_learned_conditioning([prompt])  # model.module removed
+            text_emb = model.get_learned_conditioning([prompt])

             img_tensor = image.to(torch.cuda.current_device())
             img_tensor = (img_tensor - 0.5) * 2
             image_tensor_resized = transform(img_tensor)
             videos = image_tensor_resized.unsqueeze(0)

-            z = get_latent_z(model, videos.unsqueeze(2))  # model.module removed
+            z = get_latent_z(model, videos.unsqueeze(2))
             img_tensor_repeat = repeat(z, 'b c t h w -> b c (repeat t) h w', repeat=frames)

-            cond_images = model.embedder(img_tensor.unsqueeze(0))  # model.module removed
-            img_emb = model.image_proj_model(cond_images)  # model.module removed
+            cond_images = model.embedder(img_tensor.unsqueeze(0))
+            img_emb = model.image_proj_model(cond_images)

             imtext_cond = torch.cat([text_emb, img_emb], dim=1)

             fs = torch.tensor([fs], dtype=torch.long, device=torch.cuda.current_device())
             cond = {"c_crossattn": [imtext_cond], "fs": fs, "c_concat": [img_tensor_repeat]}

-            batch_samples = batch_ddim_sampling(model, cond, noise_shape, n_samples=1, ddim_steps=steps, ddim_eta=eta, cfg_scale=cfg_scale)  # model.module removed
+            batch_samples = batch_ddim_sampling(model, cond, noise_shape, n_samples=1, ddim_steps=steps, ddim_eta=eta, cfg_scale=cfg_scale)

             video_path = './output.mp4'
             save_videos(batch_samples, './', filenames=['output'], fps=save_fps)

-            # memory cleanup
-            del text_emb, img_tensor, image_tensor_resized, videos, z, img_tensor_repeat, cond_images, img_emb, imtext_cond, cond, batch_samples
-            torch.cuda.empty_cache()
-
             return video_path

     except Exception as e:
@@ -167,6 +163,7 @@ def infer(prompt, steps=50, cfg_scale=7.5, eta=1.0, fs=3, seed=123, frames=64):
         torch.cuda.empty_cache()


+
 i2v_examples = [
     ['우주인 복장으로 기타를 치는 남자', 30, 7.5, 1.0, 6, 123, 64],
     ['time-lapse of a blooming flower with leaves and a stem', 30, 7.5, 1.0, 10, 123, 64],
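
Note for readers of the diff: `batch_ddim_sampling` still takes a `noise_shape` argument, and the removed lines show how that shape was derived from the target resolution and an 8x latent downsampling factor. The sketch below only restates that derivation for reference; the helper name `make_noise_shape` is hypothetical, while `model.model.out_channels` and the [B, C, T, H, W] layout come from the diff above.

# Illustrative sketch, not part of the commit: building the latent noise shape
# used by batch_ddim_sampling for a 576x1024 target, 8x downsampling.
def make_noise_shape(model, resolution=(576, 1024), frames=64, batch_size=1, downsample=8):
    channels = model.model.out_channels           # latent channel count (attribute used by this commit)
    h = resolution[0] // downsample               # 576 // 8 = 72
    w = resolution[1] // downsample               # 1024 // 8 = 128
    return [batch_size, channels, frames, h, w]   # [B, C, T, H, W] expected by batch_ddim_sampling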