tree3po commited on
Commit
833a58b
·
verified ·
1 Parent(s): ea0a988

Update open_oasis_master/generate.py

Browse files
Files changed (1) hide show
  1. open_oasis_master/generate.py +3 -3
open_oasis_master/generate.py CHANGED
@@ -15,13 +15,13 @@ from torch import autocast
15
  device = "cpu"
16
 
17
  # load DiT checkpoint
18
- ckpt = torch.load("oasis500m.pt")
19
  model = DiT_models["DiT-S/2"]()
20
  model.load_state_dict(ckpt, strict=False)
21
  model = model.to(device).eval()
22
 
23
  # load VAE checkpoint
24
- vae_ckpt = torch.load("vit-l-20.pt")
25
  vae = VAE_models["vit-l-20-shallow-encoder"]()
26
  vae.load_state_dict(vae_ckpt)
27
  vae = vae.to(device).eval()
@@ -40,7 +40,7 @@ video_id = "snippy-chartreuse-mastiff-f79998db196d-20220401-224517.chunk_001"
40
  mp4_path = f"sample_data/{video_id}.mp4"
41
  actions_path = f"sample_data/{video_id}.actions.pt"
42
  video = read_video(mp4_path, pts_unit="sec")[0].float() / 255
43
- actions = one_hot_actions(torch.load(actions_path))
44
  offset = 100
45
  video = video[offset:offset+total_frames].unsqueeze(0)
46
  actions = actions[offset:offset+total_frames].unsqueeze(0)
 
15
  device = "cpu"
16
 
17
  # load DiT checkpoint
18
+ ckpt = torch.load("oasis500m.pt",map_location=torch.device('cpu'))
19
  model = DiT_models["DiT-S/2"]()
20
  model.load_state_dict(ckpt, strict=False)
21
  model = model.to(device).eval()
22
 
23
  # load VAE checkpoint
24
+ vae_ckpt = torch.load("vit-l-20.pt",map_location=torch.device('cpu'))
25
  vae = VAE_models["vit-l-20-shallow-encoder"]()
26
  vae.load_state_dict(vae_ckpt)
27
  vae = vae.to(device).eval()
 
40
  mp4_path = f"sample_data/{video_id}.mp4"
41
  actions_path = f"sample_data/{video_id}.actions.pt"
42
  video = read_video(mp4_path, pts_unit="sec")[0].float() / 255
43
+ actions = one_hot_actions(torch.load(actions_path,map_location=torch.device('cpu')))
44
  offset = 100
45
  video = video[offset:offset+total_frames].unsqueeze(0)
46
  actions = actions[offset:offset+total_frames].unsqueeze(0)