# Hugging Face Space (status: Running)
from octo.model.octo_model import OctoModel | |
from PIL import Image | |
import numpy as np | |
import jax | |
from fastapi import FastAPI, HTTPException | |
from pydantic import BaseModel | |
import os | |
import io | |
import base64 | |
from typing import List | |
# Pin JAX to the CPU backend (swap "cpu" for "cuda"/"gpu" if an accelerator is available).
# JAX reads JAX_PLATFORMS from the environment when it is imported, and `jax` is already
# imported above this point, so setting the variable here does not take effect by itself.
# jax.config.update applies the choice as long as no backend has been initialized yet;
# the environment variable is still set so any child processes inherit it.
os.environ['JAX_PLATFORMS'] = 'cpu'
jax.config.update('jax_platforms', 'cpu')

# Load the Octo 1.5 (small) checkpoint once, at import time, so every request reuses it.
model = OctoModel.load_pretrained("hf://rail-berkeley/octo-small-1.5")

# FastAPI application instance.
app = FastAPI(title="Octo 1.5 Inference API")
# Request body model for the inference endpoint, parsed and validated by pydantic.
# Field names, annotations and defaults define the public JSON schema. Documentation lives
# in comments rather than a class docstring because pydantic copies a model docstring into
# the generated OpenAPI schema description.
class InferenceRequest(BaseModel):
    # One base64-encoded image per observation timestep (bare base64 or a data URL).
    # NOTE(review): no length constraint is declared, so an empty list passes validation.
    image_base64: List[str]  # List of base64-encoded images
    # Natural-language instruction given to the model.
    task: str = "pick up the fork"  # Default task
# Health check endpoint. It is registered on `app` so FastAPI actually serves it.
# NOTE(review): "/health" is a chosen convention; confirm it matches what callers probe.
@app.get("/health")
async def health_check():
    """Return a static liveness payload.

    Returns:
        dict: Always ``{"status": "healthy"}``. It does not exercise the model.
    """
    return {"status": "healthy"}
def _decode_image(img_b64: str, size=(256, 256)) -> np.ndarray:
    """Decode one base64 image (optionally a data URL) into an RGB uint8 array.

    Args:
        img_b64: Base64 payload, either bare or prefixed like
            ``"data:image/png;base64,<payload>"``.
        size: Target ``(width, height)`` passed to ``PIL.Image.resize``.

    Returns:
        np.ndarray of shape ``(size[1], size[0], 3)`` and dtype uint8.

    Raises:
        ValueError: The payload is not valid base64 (``binascii.Error`` is a
            ValueError subclass).
        OSError: PIL cannot identify or decode the bytes
            (``PIL.UnidentifiedImageError`` is an OSError subclass).
    """
    if img_b64.startswith("data:image"):
        # Keep only the payload after the first comma. partition() avoids the
        # IndexError that split(",")[1] raised when the comma is missing.
        _, _, img_b64 = img_b64.partition(",")
    img_data = base64.b64decode(img_b64)
    with Image.open(io.BytesIO(img_data)) as img:
        # Convert before resizing so RGBA, greyscale and palette uploads all become
        # 3-channel. Otherwise np.stack fails on mixed inputs, or the model receives
        # the wrong number of channels.
        return np.array(img.convert("RGB").resize(size))


# Inference endpoint
# NOTE(review): "/predict" is a chosen route path; confirm it matches existing clients.
@app.post("/predict")
async def predict(request: InferenceRequest, dataset_name: str = "bridge_dataset"):
    """Run Octo on a short image history and return the sampled actions.

    Args:
        request: Base64 images (in timestep order) plus the language task.
        dataset_name: Query parameter. It is a key into ``model.dataset_statistics``
            that selects the statistics used to unnormalize the actions.

    Returns:
        ``{"actions": [...]}``: the sampled actions with the batch dimension
        removed, as nested Python lists.

    Raises:
        HTTPException: 400 for an empty image list, undecodable image data, or an
            unknown ``dataset_name``; 500 for any other failure.
    """
    # NOTE(review): nothing below awaits, so model inference blocks the event loop
    # (including the health check) while it runs. Consider a plain `def` endpoint
    # or a threadpool if concurrent requests matter.
    if not request.image_base64:
        raise HTTPException(
            status_code=400,
            detail="Error: image_base64 must contain at least one image",
        )

    try:
        images = [_decode_image(img_b64) for img_b64 in request.image_base64]
    except (ValueError, OSError) as e:
        # Bad base64 or unreadable image bytes are client errors, not server faults.
        raise HTTPException(status_code=400, detail=f"Error: invalid image data: {e}") from e
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error: {str(e)}") from e

    try:
        if dataset_name not in model.dataset_statistics:
            raise HTTPException(
                status_code=400,
                detail=f"Error: unknown dataset_name {dataset_name!r}",
            )

        # Every image is (256, 256, 3) after _decode_image, so stacking and adding a
        # batch axis gives shape (1, T, 256, 256, 3).
        img_array = np.stack(images)[np.newaxis, ...]
        observation = {
            "image_primary": img_array,
            # Shape (1, T); all True because every supplied timestep is a real observation.
            "timestep_pad_mask": np.ones((1, len(images)), dtype=bool),
        }

        task_obj = model.create_tasks(texts=[request.task])
        actions = model.sample_actions(
            observation,
            task_obj,
            unnormalization_statistics=model.dataset_statistics[dataset_name]["action"],
            # The same key is used on every call, so identical requests presumably
            # yield identical actions.
            rng=jax.random.PRNGKey(0),
        )
        # Drop the batch dimension. The remaining shape is assumed to be
        # (action_horizon, action_dim); confirm against the Octo version in use.
        actions = actions[0]
        return {"actions": np.asarray(actions).tolist()}
    except HTTPException:
        # Let deliberate client errors through instead of re-wrapping them as 500s.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error: {str(e)}") from e