"""FastAPI service that serves action predictions from the Octo 1.5 model."""

import base64
import io
import logging
import os
from typing import List

# JAX reads JAX_PLATFORMS when it is first imported, so this must run before
# importing jax or octo (which imports jax); setting it afterwards is a no-op.
# setdefault keeps an explicit deployment-time choice (e.g. JAX_PLATFORMS=cuda)
# instead of silently forcing CPU.
os.environ.setdefault("JAX_PLATFORMS", "cpu")

import jax  # noqa: E402
import numpy as np  # noqa: E402
from fastapi import FastAPI, HTTPException  # noqa: E402
from octo.model.octo_model import OctoModel  # noqa: E402
from PIL import Image  # noqa: E402
from pydantic import BaseModel  # noqa: E402

logger = logging.getLogger(__name__)

# Every incoming frame is resized to IMAGE_SIZE x IMAGE_SIZE before inference.
IMAGE_SIZE = 256

# Load the Octo 1.5 checkpoint once at import time so all requests share it.
model = OctoModel.load_pretrained("hf://rail-berkeley/octo-small-1.5")

app = FastAPI(title="Octo 1.5 Inference API")


class InferenceRequest(BaseModel):
    """JSON body of a POST /predict request."""

    # Base64-encoded images, stacked along the time axis in the order given.
    # An optional "data:image/...;base64," prefix is accepted and stripped.
    image_base64: List[str]
    # Natural-language instruction passed to model.create_tasks.
    task: str = "pick up the fork"


@app.get("/health")
async def health_check():
    """Liveness probe: always returns a static "healthy" payload."""
    return {"status": "healthy"}


def _decode_image(img_base64: str) -> np.ndarray:
    """Decode one base64 (optionally data-URL) image into a resized RGB uint8 array.

    Args:
        img_base64: Base64 payload, optionally prefixed with "data:image/...,".

    Returns:
        Array of shape (IMAGE_SIZE, IMAGE_SIZE, 3), dtype uint8.

    Raises:
        ValueError: If the payload is not valid base64 (binascii.Error).
        OSError: If the bytes are not an image Pillow can read.
        Image.DecompressionBombError: If the image is implausibly large.
    """
    if img_base64.startswith("data:image"):
        # Everything after the first comma is the payload. partition() yields
        # "" (then a 400 below) rather than IndexError when no comma exists.
        img_base64 = img_base64.partition(",")[2]
    img_data = base64.b64decode(img_base64)
    with Image.open(io.BytesIO(img_data)) as img:
        # convert("RGB") guarantees exactly 3 channels: RGBA, grayscale (L)
        # or palette (P) inputs would otherwise produce arrays of a different
        # shape and break np.stack / the model's image_primary input.
        rgb = img.convert("RGB").resize((IMAGE_SIZE, IMAGE_SIZE))
    return np.array(rgb)


# Declared with plain `def` (not `async def`) so FastAPI runs this CPU-heavy,
# blocking inference in its threadpool instead of stalling the event loop
# (which would also block /health while a prediction is running).
@app.post("/predict")
def predict(request: InferenceRequest, dataset_name: str = "bridge_dataset"):
    """Predict actions for a short image history and a language instruction.

    Args:
        request: Base64-encoded frames plus the task instruction.
        dataset_name: Key into ``model.dataset_statistics``; its ``"action"``
            entry is used to unnormalize the sampled actions.

    Returns:
        ``{"actions": [...]}``: the first (only) batch element of
        ``model.sample_actions`` as nested Python lists.

    Raises:
        HTTPException: 400 if no images are given, an image cannot be
            decoded, or ``dataset_name`` is unknown; 500 if inference fails.
    """
    if not request.image_base64:
        raise HTTPException(
            status_code=400, detail="image_base64 must contain at least one image"
        )
    if dataset_name not in model.dataset_statistics:
        available = sorted(model.dataset_statistics)
        raise HTTPException(
            status_code=400,
            detail=f"Unknown dataset_name {dataset_name!r}; available: {available}",
        )

    try:
        images = [_decode_image(item) for item in request.image_base64]
    except (ValueError, OSError, Image.DecompressionBombError) as e:
        raise HTTPException(status_code=400, detail=f"Invalid image: {e}") from e

    # Add a leading batch axis: (1, T, IMAGE_SIZE, IMAGE_SIZE, 3).
    img_array = np.stack(images)[np.newaxis, ...]
    observation = {
        "image_primary": img_array,
        # Every supplied timestep is real (none are padding): shape (1, T).
        "timestep_pad_mask": np.ones((1, len(images)), dtype=bool),
    }

    try:
        task_obj = model.create_tasks(texts=[request.task])
        actions = model.sample_actions(
            observation,
            task_obj,
            unnormalization_statistics=model.dataset_statistics[dataset_name]["action"],
            # Fixed key: sampling is deterministic, so identical requests
            # receive identical actions.
            rng=jax.random.PRNGKey(0),
        )
    except Exception as e:
        logger.exception("Octo inference failed")
        raise HTTPException(status_code=500, detail=f"Error: {str(e)}") from e

    # Drop the batch axis (batch size is 1 here) and convert the device array
    # to host memory before serializing to JSON-compatible lists.
    return {"actions": np.asarray(actions[0]).tolist()}