xierui.0097 commited on
Commit
cf76b2c
·
1 Parent(s): d2b5b0e
video_to_video/video_to_video_model.py CHANGED
@@ -17,10 +17,10 @@ from diffusers import AutoencoderKLTemporalDecoder
17
  import requests
18
 
19
  def download_model(url, model_path):
20
- if not os.path.exists(model_path):
21
  print(f"Model not found at {model_path}, downloading...")
22
  response = requests.get(url, stream=True)
23
- with open(model_path, 'wb') as f:
24
  for chunk in response.iter_content(chunk_size=1024):
25
  if chunk:
26
  f.write(chunk)
@@ -37,7 +37,7 @@ class VideoToVideo_sr():
37
  self.device = device # torch.device(f'cuda:0')
38
 
39
  # text_encoder
40
- text_encoder = FrozenOpenCLIPEmbedder(device=self.device, pretrained="laion2b_s32b_b79k")
41
  text_encoder.model.to(self.device)
42
  self.text_encoder = text_encoder
43
  logger.info(f'Build encoder with FrozenOpenCLIPEmbedder')
@@ -80,7 +80,7 @@ class VideoToVideo_sr():
80
 
81
  # Temporal VAE
82
  vae = AutoencoderKLTemporalDecoder.from_pretrained(
83
- "stabilityai/stable-video-diffusion-img2vid", subfolder="vae", variant="fp16"
84
  )
85
  vae.eval()
86
  vae.requires_grad_(False)
 
17
  import requests
18
 
19
  def download_model(url, model_path):
20
+ if not os.path.exists(os.path.join(model_path, 'heavy_deg.pt')):
21
  print(f"Model not found at {model_path}, downloading...")
22
  response = requests.get(url, stream=True)
23
+ with open(os.path.join(model_path, 'heavy_deg.pt'), 'wb') as f:
24
  for chunk in response.iter_content(chunk_size=1024):
25
  if chunk:
26
  f.write(chunk)
 
37
  self.device = device # torch.device(f'cuda:0')
38
 
39
  # text_encoder
40
+ text_encoder = FrozenOpenCLIPEmbedder(device=self.device, pretrained="/home/test/Workspace/yhliu/VSR/ours/checkpoints/open_clip_pytorch_model.bin")
41
  text_encoder.model.to(self.device)
42
  self.text_encoder = text_encoder
43
  logger.info(f'Build encoder with FrozenOpenCLIPEmbedder')
 
80
 
81
  # Temporal VAE
82
  vae = AutoencoderKLTemporalDecoder.from_pretrained(
83
+ "/home/test/Workspace/yhliu/VSR/ours/checkpoints/stable-video-diffusion-img2vid", subfolder="vae", variant="fp16"
84
  )
85
  vae.eval()
86
  vae.requires_grad_(False)