Nirav Madhani commited on
Commit
0558e79
·
1 Parent(s): ab273a0

Flash server

Browse files
Files changed (5) hide show
  1. Dockerfile +27 -3
  2. app.py +65 -0
  3. init_model.py +9 -0
  4. main.py +13 -6
  5. test_api.py +54 -0
Dockerfile CHANGED
@@ -1,11 +1,35 @@
1
- FROM python:3.10
 
 
 
 
 
2
 
3
  WORKDIR /app
 
 
4
  RUN git clone https://github.com/octo-models/octo.git
5
  WORKDIR /app/octo
 
 
6
  RUN pip3 install -e .
7
  RUN pip3 install -r requirements.txt
8
  RUN pip3 install --upgrade "jax[cuda11_pip]==0.4.20" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
9
- RUN pip install scipy==1.10.1
 
 
 
 
 
 
 
 
 
10
  COPY main.py /app/octo
11
- CMD ["python", "main.py"]
 
 
 
 
 
 
 
1
+ FROM python:3.10-slim
2
+
3
+ # Update package list and install git
4
+ RUN apt-get update && \
5
+ apt-get install -y git && \
6
+ rm -rf /var/lib/apt/lists/* # Clean up to reduce image size
7
 
8
  WORKDIR /app
9
+
10
+ # Clone the octo repository
11
  RUN git clone https://github.com/octo-models/octo.git
12
  WORKDIR /app/octo
13
+
14
+ # Install dependencies
15
  RUN pip3 install -e .
16
  RUN pip3 install -r requirements.txt
17
  RUN pip3 install --upgrade "jax[cuda11_pip]==0.4.20" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
18
+ RUN pip3 install scipy==1.10.1
19
+
20
+ # Install FastAPI and Uvicorn for the API
21
+ RUN pip3 install fastapi uvicorn
22
+
23
+ # Copy and run the model initialization script to cache the model
24
+ COPY init_model.py /app/octo
25
+ RUN python init_model.py
26
+
27
+ # Copy the original main.py and the API app.py
28
  COPY main.py /app/octo
29
+ COPY app.py /app/octo
30
+
31
+ # Expose port 8000 for the API
32
+ EXPOSE 8000
33
+
34
+ # Run the API with Uvicorn
35
+ CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "8000"]
app.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from octo.model.octo_model import OctoModel
2
+ from PIL import Image
3
+ import numpy as np
4
+ import jax
5
+ from fastapi import FastAPI, HTTPException
6
+ from pydantic import BaseModel
7
+ import os
8
+ import io
9
+ import base64
10
+
11
+ # Set JAX to use CPU platform (adjust if GPU is needed)
12
+ os.environ['JAX_PLATFORMS'] = 'cpu'
13
+
14
+ # Load the model once globally (assumes it's cached locally)
15
+ model = OctoModel.load_pretrained("hf://rail-berkeley/octo-small-1.5")
16
+
17
+ # Initialize FastAPI app
18
+ app = FastAPI(title="Octo Model Inference API")
19
+
20
+ # Define request body model
21
+ class InferenceRequest(BaseModel):
22
+ image_base64: str # Base64-encoded image string
23
+ task: str = "pick up the fork" # Default task
24
+
25
+ # Health check endpoint
26
+ @app.get("/health")
27
+ async def health_check():
28
+ return {"status": "healthy"}
29
+
30
+ # Inference endpoint
31
+ @app.post("/predict")
32
+ async def predict(request: InferenceRequest):
33
+ try:
34
+ # Decode base64 image
35
+ img_base64 = request.image_base64
36
+ if img_base64.startswith("data:image"):
37
+ img_base64 = img_base64.split(",")[1]
38
+
39
+ img_data = base64.b64decode(img_base64)
40
+ img = Image.open(io.BytesIO(img_data)).resize((256, 256))
41
+ img = np.array(img)
42
+
43
+ # Add batch and time horizon dimensions
44
+ img = img[np.newaxis, np.newaxis, ...] # Shape: (1, 1, 256, 256, 3)
45
+ observation = {
46
+ "image_primary": img,
47
+ "timestep_pad_mask": np.array([[True]])
48
+ }
49
+
50
+ # Create task and predict actions
51
+ task_obj = model.create_tasks(texts=[request.task])
52
+ actions = model.sample_actions(
53
+ observation,
54
+ task_obj,
55
+ unnormalization_statistics=model.dataset_statistics["bridge_dataset"]["action"],
56
+ rng=jax.random.PRNGKey(0)
57
+ )
58
+ actions = actions[0]
59
+
60
+ # Convert NumPy array to list for JSON response
61
+ actions_list = actions.tolist()
62
+
63
+ return {"actions": actions_list}
64
+ except Exception as e:
65
+ raise HTTPException(status_code=500, detail=f"Error processing request: {str(e)}")
init_model.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ from octo.model.octo_model import OctoModel
2
+ import os
3
+
4
+ # Set JAX to CPU (consistent with your setup)
5
+ os.environ['JAX_PLATFORMS'] = 'cpu'
6
+
7
+ # Load the model to cache it
8
+ model = OctoModel.load_pretrained("hf://rail-berkeley/octo-small-1.5")
9
+ print("Model downloaded and cached successfully.")
main.py CHANGED
@@ -17,10 +17,17 @@ img = np.array(Image.open(requests.get(IMAGE_URL, stream=True).raw).resize((256,
17
  img = img[np.newaxis,np.newaxis,...]
18
  observation = {"image_primary": img, "timestep_pad_mask": np.array([[True]])}
19
  task = model.create_tasks(texts=["pick up the fork"])
20
- action = model.sample_actions(
21
- observation,
22
- task,
23
- unnormalization_statistics=model.dataset_statistics["bridge_dataset"]["action"],
24
- rng=jax.random.PRNGKey(0)
25
  )
26
- print(action)
 
 
 
 
 
 
 
 
17
  img = img[np.newaxis,np.newaxis,...]
18
  observation = {"image_primary": img, "timestep_pad_mask": np.array([[True]])}
19
  task = model.create_tasks(texts=["pick up the fork"])
20
+ norm_actions = model.sample_actions(observation, task, rng=jax.random.PRNGKey(0))
21
+ norm_actions = norm_actions[0] # remove batch
22
+ actions = (
23
+ norm_actions * model.dataset_statistics["bridge_dataset"]['action']['std']
24
+ + model.dataset_statistics["bridge_dataset"]['action']['mean']
25
  )
26
+ actions = np.concatenate(
27
+ (
28
+ steps[step+1]['action']['world_vector'],
29
+ steps[step+1]['action']['rotation_delta'],
30
+ np.array(steps[step+1]['action']['open_gripper']).astype(np.float32)[None]
31
+ ), axis=-1
32
+ )
33
+ print(actions)
test_api.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import requests
2
+ import base64
3
+ from PIL import Image
4
+ import io
5
+
6
+ # API endpoint URL (adjust if running on a different host/port)
7
+ API_URL = "http://localhost:8000/predict"
8
+
9
+ # Example image URL from main.py
10
+ IMAGE_URL = "https://rail.eecs.berkeley.edu/datasets/bridge_release/raw/bridge_data_v2/datacol2_toykitchen7/drawer_pnp/01/2023-04-19_09-18-15/raw/traj_group0/traj0/images0/im_12.jpg"
11
+ TASK_TEXT = "pick up the fork"
12
+
13
+ def test_api(image_url=IMAGE_URL, task=TASK_TEXT):
14
+ try:
15
+ # Download image from URL
16
+ response = requests.get(image_url, stream=True)
17
+ response.raise_for_status() # Check for HTTP errors
18
+ img = Image.open(response.raw).resize((256, 256))
19
+
20
+ # Convert image to base64
21
+ img_byte_arr = io.BytesIO()
22
+ img.save(img_byte_arr, format="JPEG") # Save as JPEG (adjust if needed)
23
+ img_byte_arr = img_byte_arr.getvalue()
24
+ base64_string = base64.b64encode(img_byte_arr).decode("utf-8")
25
+
26
+ # Prepare payload for API
27
+ payload = {
28
+ "image_base64": base64_string,
29
+ "task": task
30
+ }
31
+
32
+ # Send POST request to API
33
+ api_response = requests.post(API_URL, json=payload)
34
+ api_response.raise_for_status() # Check for API errors
35
+
36
+ # Print the result
37
+ result = api_response.json()
38
+ print(f"Task: {task}")
39
+ print(f"Image URL: {image_url}")
40
+ print(f"Predicted Actions: {result['actions']}")
41
+
42
+ except requests.exceptions.RequestException as e:
43
+ print(f"Error fetching image or calling API: {e}")
44
+ except Exception as e:
45
+ print(f"Unexpected error: {e}")
46
+
47
+ if __name__ == "__main__":
48
+ # Test with default values (same as main.py)
49
+ test_api()
50
+
51
+ # Test with a different URL and task (optional)
52
+ # Replace with another valid URL if desired
53
+ print("\nTesting with another URL and task:")
54
+ test_api(IMAGE_URL, TASK_TEXT)