seawolf2357 committed
Commit 5e9229f · verified · 1 Parent(s): e17e16f

Update app.py

Files changed (1)
  1. app.py  +15 -15
app.py CHANGED
@@ -22,17 +22,14 @@ from funcs import (
 )
 from transformers import pipeline
 from diffusers import StableDiffusionXLPipeline
-#import spaces
 import tensorflow as tf
-print(tf.__version__)
 
+print(tf.__version__)
 print("GPU available:", len(tf.config.list_physical_devices('GPU')) > 0)
 
-
-
 def is_tensor(x):
     return tf.is_tensor(x)
-
+
 os.environ['KERAS_BACKEND'] = 'tensorflow'
 
 def download_model():
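Note: the hunk above only checks TensorFlow's GPU visibility, while the DynamiCrafter model itself runs on PyTorch. A minimal CUDA-side check (a sketch, not part of this commit) would be:

import torch

print("Torch CUDA available:", torch.cuda.is_available())
print("Torch CUDA device count:", torch.cuda.device_count())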
@@ -46,19 +43,23 @@ def download_model():
         hf_hub_download(repo_id=REPO_ID, filename=filename, local_dir='./checkpoints/dynamicrafter_1024_v1/', force_download=True)
 
 download_model()
-ckpt_path='checkpoints/dynamicrafter_1024_v1/model.ckpt'
-config_file='configs/inference_1024_v1.0.yaml'
+ckpt_path = 'checkpoints/dynamicrafter_1024_v1/model.ckpt'
+config_file = 'configs/inference_1024_v1.0.yaml'
 config = OmegaConf.load(config_file)
 model_config = config.pop("model", OmegaConf.create())
-model_config['params']['unet_config']['params']['use_checkpoint']=True
+model_config['params']['unet_config']['params']['use_checkpoint'] = True
 model = instantiate_from_config(model_config)
 assert os.path.exists(ckpt_path), "Error: checkpoint Not Found!"
 model = load_model_checkpoint(model, ckpt_path)
 model.eval()
+
+# Wrap the model in DataParallel so it can run across multiple GPUs
+model = torch.nn.DataParallel(model)
 model = model.cuda()
 
 # Load the translation model
 translator = pipeline("translation", model="Helsinki-NLP/opus-mt-ko-en", device=0 if torch.cuda.is_available() else -1, framework="pt")
+
 # Load the image generation model
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 pipe = StableDiffusionXLPipeline.from_pretrained(
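Because torch.nn.DataParallel only parallelizes forward() calls, attributes and custom methods of the wrapped model have to be reached through .module, which is why the inference hunk below switches to model.module. A minimal sketch of that pattern, using a hypothetical ToyNet in place of the DynamiCrafter model:

import torch
import torch.nn as nn

class ToyNet(nn.Module):
    # Hypothetical stand-in for the wrapped model, with one non-forward helper method.
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)

    def forward(self, x):
        return self.proj(x)

    def get_learned_conditioning(self, x):
        return self.proj(x).detach()

model = nn.DataParallel(ToyNet())      # forward() calls get replicated across visible GPUs
if torch.cuda.is_available():
    model = model.cuda()

x = torch.randn(4, 8, device=next(model.parameters()).device)
y = model(x)                                       # routed through DataParallel.forward
cond = model.module.get_learned_conditioning(x)    # custom methods live on the underlying module

With a single GPU the wrapper is effectively a pass-through, so the .module indirection is the only visible change.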
@@ -126,30 +127,30 @@ def infer(prompt, steps=50, cfg_scale=7.5, eta=1.0, fs=3, seed=123, frames=64):
         steps = 60
 
     batch_size = 1
-    channels = model.model.diffusion_model.out_channels
+    channels = model.module.model.diffusion_model.out_channels  # access via model.module when wrapped in DataParallel
     h, w = resolution[0] // 8, resolution[1] // 8
     noise_shape = [batch_size, channels, frames, h, w]
 
     with torch.no_grad(), torch.cuda.amp.autocast():
-        text_emb = model.get_learned_conditioning([prompt])
+        text_emb = model.module.get_learned_conditioning([prompt])
 
         img_tensor = image.to(model.device)
        img_tensor = (img_tensor - 0.5) * 2
         image_tensor_resized = transform(img_tensor)
         videos = image_tensor_resized.unsqueeze(0)
 
-        z = get_latent_z(model, videos.unsqueeze(2))
+        z = get_latent_z(model.module, videos.unsqueeze(2))
         img_tensor_repeat = repeat(z, 'b c t h w -> b c (repeat t) h w', repeat=frames)
 
-        cond_images = model.embedder(img_tensor.unsqueeze(0))
-        img_emb = model.image_proj_model(cond_images)
+        cond_images = model.module.embedder(img_tensor.unsqueeze(0))
+        img_emb = model.module.image_proj_model(cond_images)
 
         imtext_cond = torch.cat([text_emb, img_emb], dim=1)
 
         fs = torch.tensor([fs], dtype=torch.long, device=model.device)
         cond = {"c_crossattn": [imtext_cond], "fs": fs, "c_concat": [img_tensor_repeat]}
 
-        batch_samples = batch_ddim_sampling(model, cond, noise_shape, n_samples=1, ddim_steps=steps, ddim_eta=eta, cfg_scale=cfg_scale)
+        batch_samples = batch_ddim_sampling(model.module, cond, noise_shape, n_samples=1, ddim_steps=steps, ddim_eta=eta, cfg_scale=cfg_scale)
 
     video_path = './output.mp4'
     save_videos(batch_samples, './', filenames=['output'], fps=save_fps)
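For reference, the latent noise shape above follows from the VAE's 8x spatial downsampling. Assuming the 1024 variant's 576x1024 default resolution and a 4-channel latent (both assumptions, not stated in this diff), the numbers work out as:

# Illustrative only: resolution and latent channel count are assumed values.
batch_size, frames = 1, 64
channels = 4                                 # typical latent channel count, assumed
resolution = (576, 1024)                     # (height, width), assumed default for the 1024 model
h, w = resolution[0] // 8, resolution[1] // 8          # -> 72, 128
noise_shape = [batch_size, channels, frames, h, w]     # -> [1, 4, 64, 72, 128]
print(noise_shape)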
@@ -203,5 +204,4 @@ with gr.Blocks(analytics_enabled=False, css=css) as dynamicrafter_iface:
         fn = infer
     )
 
-# dynamicrafter_iface.launch(server_port=7930, server_name="0.0.0.0", share=True)
 dynamicrafter_iface.launch()
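The commented-out launch line removed here is only needed when self-hosting; on Spaces the default launch() is enough. A hedged sketch of the self-hosted variant, mirroring the arguments of the removed comment:

# Only for running outside Spaces; mirrors the removed comment above.
dynamicrafter_iface.launch(server_name="0.0.0.0", server_port=7930, share=True)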
 