from octo.model.octo_model import OctoModel
from PIL import Image
import numpy as np
import jax
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import os
import io
import base64
from typing import List
# Set JAX to use CPU (adjust to GPU if available)
os.environ['JAX_PLATFORMS'] = 'cpu'
# Load Octo 1.5 model globally
model = OctoModel.load_pretrained("hf://rail-berkeley/octo-small-1.5")
# Initialize FastAPI app
app = FastAPI(title="Octo 1.5 Inference API")
# Request body model
class InferenceRequest(BaseModel):
    image_base64: List[str]  # List of base64-encoded images
    task: str = "pick up the fork"  # Default task instruction
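
# For reference, a request body for /predict might look like the following
# (illustrative values only; the data-URI prefix is optional and stripped below):
# {
#   "image_base64": ["data:image/png;base64,<...>", "<...>"],
#   "task": "pick up the fork"
# }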
# Health check endpoint
@app.get("/health")
async def health_check():
    return {"status": "healthy"}
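
# Optional helper (a minimal sketch, not part of the original app): /predict takes a
# dataset_name query parameter that must be a key of model.dataset_statistics, so an
# endpoint listing those keys makes the valid values discoverable. This assumes the
# statistics object is a plain dict keyed by dataset name, as the indexing in
# /predict suggests.
@app.get("/datasets")
async def list_datasets():
    return {"datasets": sorted(model.dataset_statistics.keys())}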
# Inference endpoint
@app.post("/predict")
async def predict(request: InferenceRequest, dataset_name: str = "bridge_dataset"):
    try:
        # Decode and preprocess images
        images = []
        for img_base64 in request.image_base64:
            if img_base64.startswith("data:image"):
                img_base64 = img_base64.split(",")[1]
            img_data = base64.b64decode(img_base64)
            # Force RGB so grayscale/alpha inputs still yield a (256, 256, 3) array
            img = Image.open(io.BytesIO(img_data)).convert("RGB").resize((256, 256))
            img = np.array(img)
            images.append(img)
        # Stack images and add a batch dimension
        img_array = np.stack(images)[np.newaxis, ...]  # Shape: (1, T, 256, 256, 3)
        observation = {
            "image_primary": img_array,
            "timestep_pad_mask": np.ones((1, len(images)), dtype=bool)  # Shape: (1, T)
        }
        # Create task and predict actions
        task_obj = model.create_tasks(texts=[request.task])
        actions = model.sample_actions(
            observation,
            task_obj,
            unnormalization_statistics=model.dataset_statistics[dataset_name]["action"],
            rng=jax.random.PRNGKey(0)
        )
        actions = actions[0]  # Drop the batch dimension; shape: (action_horizon, action_dim)
        return {"actions": actions.tolist()}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error: {str(e)}")
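
# Local entry point (a sketch under assumptions): uvicorn is the usual ASGI server
# for FastAPI, and 7860 is the port conventionally exposed by Hugging Face Spaces;
# adjust host/port to your deployment.
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)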