Nirav-Madhani commited on
Commit
a61e149
·
verified ·
1 Parent(s): b6cc7a0

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +25 -0
main.py CHANGED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from octo.model.octo_model import OctoModel
2
+ from PIL import Image
3
+ import requests
4
+ import matplotlib.pyplot as plt
5
+ import numpy as np
6
+ import jax
7
+ os.environ['JAX_PLATFORMS'] = 'cpu'
8
+ model = OctoModel.load_pretrained("hf://rail-berkeley/octo-small-1.5")
9
+
10
+ # download one example BridgeV2 image
11
+ IMAGE_URL = "https://rail.eecs.berkeley.edu/datasets/bridge_release/raw/bridge_data_v2/datacol2_toykitchen7/drawer_pnp/01/2023-04-19_09-18-15/raw/traj_group0/traj0/images0/im_12.jpg"
12
+ img = np.array(Image.open(requests.get(IMAGE_URL, stream=True).raw).resize((256, 256)))
13
+
14
+
15
+ # add batch + time horizon 1
16
+ img = img[np.newaxis,np.newaxis,...]
17
+ observation = {"image_primary": img, "timestep_pad_mask": np.array([[True]])}
18
+ task = model.create_tasks(texts=["pick up the fork"])
19
+ action = model.sample_actions(
20
+ observation,
21
+ task,
22
+ unnormalization_statistics=model.dataset_statistics["bridge_dataset"]["action"],
23
+ rng=jax.random.PRNGKey(0)
24
+ )
25
+ print(action)