from octo.model.octo_model import OctoModel from PIL import Image import requests import matplotlib.pyplot as plt import numpy as np import jax import os os.environ['JAX_PLATFORMS'] = 'cpu' model = OctoModel.load_pretrained("hf://rail-berkeley/octo-small-1.5") # download one example BridgeV2 image IMAGE_URL = "https://rail.eecs.berkeley.edu/datasets/bridge_release/raw/bridge_data_v2/datacol2_toykitchen7/drawer_pnp/01/2023-04-19_09-18-15/raw/traj_group0/traj0/images0/im_12.jpg" img = np.array(Image.open(requests.get(IMAGE_URL, stream=True).raw).resize((256, 256))) # add batch + time horizon 1 img = img[np.newaxis,np.newaxis,...] observation = {"image_primary": img, "timestep_pad_mask": np.array([[True]])} task = model.create_tasks(texts=["pick up the fork"]) action = model.sample_actions( observation, task, unnormalization_statistics=model.dataset_statistics["bridge_dataset"]["action"], rng=jax.random.PRNGKey(0) ) print(action)