File size: 1,332 Bytes
a61e149
 
 
 
 
 
ab273a0
a61e149
 
 
 
 
 
 
 
 
 
 
 
0558e79
 
 
 
 
a61e149
0558e79
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
from octo.model.octo_model import OctoModel
from PIL import Image
import requests
import matplotlib.pyplot as plt
import numpy as np
import jax
import os
# Force CPU execution. JAX reads the JAX_PLATFORMS environment variable when
# `jax` is imported, which already happened above, so the env var alone is too
# late; also pin the platform via the config API, which takes effect as long
# as no JAX backend has been initialised yet.
os.environ['JAX_PLATFORMS'] = 'cpu'
jax.config.update("jax_platforms", "cpu")

model = OctoModel.load_pretrained("hf://rail-berkeley/octo-small-1.5")

# download one example BridgeV2 image
IMAGE_URL = "https://rail.eecs.berkeley.edu/datasets/bridge_release/raw/bridge_data_v2/datacol2_toykitchen7/drawer_pnp/01/2023-04-19_09-18-15/raw/traj_group0/traj0/images0/im_12.jpg"
with requests.get(IMAGE_URL, stream=True, timeout=30) as response:
    # Fail loudly on HTTP errors instead of handing an error page to PIL.
    response.raise_for_status()
    # convert() forces the (lazy) decode while the stream is still open and
    # guarantees 3 channels, so the array below is (256, 256, 3).
    img = np.array(Image.open(response.raw).convert("RGB").resize((256, 256)))

# add batch + time horizon 1: (H, W, C) -> (1, 1, H, W, C)
img = img[np.newaxis, np.newaxis, ...]
observation = {"image_primary": img, "timestep_pad_mask": np.array([[True]])}
task = model.create_tasks(texts=["pick up the fork"])
norm_actions = model.sample_actions(observation, task, rng=jax.random.PRNGKey(0))
norm_actions = norm_actions[0]  # remove batch

# Undo the dataset normalisation: action = normalized * std + mean.
# NOTE(review): this un-normalises every action dimension, including the
# gripper; confirm against the dataset statistics whether any dimension
# should be excluded.
action_stats = model.dataset_statistics["bridge_dataset"]['action']
actions = norm_actions * action_stats['std'] + action_stats['mean']

# The original script then overwrote `actions` with ground-truth values built
# from an undefined `steps` episode (a NameError). Comparing against ground
# truth requires loading a BridgeV2 episode first; it is not done here.
print(actions)