Spaces:
Runtime error
Runtime error
Update open_oasis_master/generate.py
Browse files
open_oasis_master/generate.py
CHANGED
@@ -15,13 +15,13 @@ from torch import autocast
|
|
15 |
device = "cpu"
|
16 |
|
17 |
# load DiT checkpoint
|
18 |
-
ckpt = torch.load("oasis500m.pt")
|
19 |
model = DiT_models["DiT-S/2"]()
|
20 |
model.load_state_dict(ckpt, strict=False)
|
21 |
model = model.to(device).eval()
|
22 |
|
23 |
# load VAE checkpoint
|
24 |
-
vae_ckpt = torch.load("vit-l-20.pt")
|
25 |
vae = VAE_models["vit-l-20-shallow-encoder"]()
|
26 |
vae.load_state_dict(vae_ckpt)
|
27 |
vae = vae.to(device).eval()
|
@@ -40,7 +40,7 @@ video_id = "snippy-chartreuse-mastiff-f79998db196d-20220401-224517.chunk_001"
|
|
40 |
mp4_path = f"sample_data/{video_id}.mp4"
|
41 |
actions_path = f"sample_data/{video_id}.actions.pt"
|
42 |
video = read_video(mp4_path, pts_unit="sec")[0].float() / 255
|
43 |
-
actions = one_hot_actions(torch.load(actions_path))
|
44 |
offset = 100
|
45 |
video = video[offset:offset+total_frames].unsqueeze(0)
|
46 |
actions = actions[offset:offset+total_frames].unsqueeze(0)
|
|
|
15 |
device = "cpu"
|
16 |
|
17 |
# load DiT checkpoint
|
18 |
+
ckpt = torch.load("oasis500m.pt",map_location=torch.device('cpu'))
|
19 |
model = DiT_models["DiT-S/2"]()
|
20 |
model.load_state_dict(ckpt, strict=False)
|
21 |
model = model.to(device).eval()
|
22 |
|
23 |
# load VAE checkpoint
|
24 |
+
vae_ckpt = torch.load("vit-l-20.pt",map_location=torch.device('cpu'))
|
25 |
vae = VAE_models["vit-l-20-shallow-encoder"]()
|
26 |
vae.load_state_dict(vae_ckpt)
|
27 |
vae = vae.to(device).eval()
|
|
|
40 |
mp4_path = f"sample_data/{video_id}.mp4"
|
41 |
actions_path = f"sample_data/{video_id}.actions.pt"
|
42 |
video = read_video(mp4_path, pts_unit="sec")[0].float() / 255
|
43 |
+
actions = one_hot_actions(torch.load(actions_path,map_location=torch.device('cpu')))
|
44 |
offset = 100
|
45 |
video = video[offset:offset+total_frames].unsqueeze(0)
|
46 |
actions = actions[offset:offset+total_frames].unsqueeze(0)
|