Spaces:
Sleeping
Sleeping
Nirav Madhani
commited on
Commit
·
0558e79
1
Parent(s):
ab273a0
Flash server
Browse files- Dockerfile +27 -3
- app.py +65 -0
- init_model.py +9 -0
- main.py +13 -6
- test_api.py +54 -0
Dockerfile
CHANGED
@@ -1,11 +1,35 @@
|
|
1 |
-
FROM python:3.10
|
|
|
|
|
|
|
|
|
|
|
2 |
|
3 |
WORKDIR /app
|
|
|
|
|
4 |
RUN git clone https://github.com/octo-models/octo.git
|
5 |
WORKDIR /app/octo
|
|
|
|
|
6 |
RUN pip3 install -e .
|
7 |
RUN pip3 install -r requirements.txt
|
8 |
RUN pip3 install --upgrade "jax[cuda11_pip]==0.4.20" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
|
9 |
-
RUN
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
COPY main.py /app/octo
|
11 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
FROM python:3.10-slim
|
2 |
+
|
3 |
+
# Update package list and install git
|
4 |
+
RUN apt-get update && \
|
5 |
+
apt-get install -y git && \
|
6 |
+
rm -rf /var/lib/apt/lists/* # Clean up to reduce image size
|
7 |
|
8 |
WORKDIR /app
|
9 |
+
|
10 |
+
# Clone the octo repository
|
11 |
RUN git clone https://github.com/octo-models/octo.git
|
12 |
WORKDIR /app/octo
|
13 |
+
|
14 |
+
# Install dependencies
|
15 |
RUN pip3 install -e .
|
16 |
RUN pip3 install -r requirements.txt
|
17 |
RUN pip3 install --upgrade "jax[cuda11_pip]==0.4.20" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
|
18 |
+
RUN pip3 install scipy==1.10.1
|
19 |
+
|
20 |
+
# Install FastAPI and Uvicorn for the API
|
21 |
+
RUN pip3 install fastapi uvicorn
|
22 |
+
|
23 |
+
# Copy and run the model initialization script to cache the model
|
24 |
+
COPY init_model.py /app/octo
|
25 |
+
RUN python init_model.py
|
26 |
+
|
27 |
+
# Copy the original main.py and the API app.py
|
28 |
COPY main.py /app/octo
|
29 |
+
COPY app.py /app/octo
|
30 |
+
|
31 |
+
# Expose port 8000 for the API
|
32 |
+
EXPOSE 8000
|
33 |
+
|
34 |
+
# Run the API with Uvicorn
|
35 |
+
CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "8000"]
|
app.py
ADDED
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from octo.model.octo_model import OctoModel
|
2 |
+
from PIL import Image
|
3 |
+
import numpy as np
|
4 |
+
import jax
|
5 |
+
from fastapi import FastAPI, HTTPException
|
6 |
+
from pydantic import BaseModel
|
7 |
+
import os
|
8 |
+
import io
|
9 |
+
import base64
|
10 |
+
|
11 |
+
# Set JAX to use CPU platform (adjust if GPU is needed)
|
12 |
+
os.environ['JAX_PLATFORMS'] = 'cpu'
|
13 |
+
|
14 |
+
# Load the model once globally (assumes it's cached locally)
|
15 |
+
model = OctoModel.load_pretrained("hf://rail-berkeley/octo-small-1.5")
|
16 |
+
|
17 |
+
# Initialize FastAPI app
|
18 |
+
app = FastAPI(title="Octo Model Inference API")
|
19 |
+
|
20 |
+
# Define request body model
|
21 |
+
class InferenceRequest(BaseModel):
|
22 |
+
image_base64: str # Base64-encoded image string
|
23 |
+
task: str = "pick up the fork" # Default task
|
24 |
+
|
25 |
+
# Health check endpoint
|
26 |
+
@app.get("/health")
|
27 |
+
async def health_check():
|
28 |
+
return {"status": "healthy"}
|
29 |
+
|
30 |
+
# Inference endpoint
|
31 |
+
@app.post("/predict")
|
32 |
+
async def predict(request: InferenceRequest):
|
33 |
+
try:
|
34 |
+
# Decode base64 image
|
35 |
+
img_base64 = request.image_base64
|
36 |
+
if img_base64.startswith("data:image"):
|
37 |
+
img_base64 = img_base64.split(",")[1]
|
38 |
+
|
39 |
+
img_data = base64.b64decode(img_base64)
|
40 |
+
img = Image.open(io.BytesIO(img_data)).resize((256, 256))
|
41 |
+
img = np.array(img)
|
42 |
+
|
43 |
+
# Add batch and time horizon dimensions
|
44 |
+
img = img[np.newaxis, np.newaxis, ...] # Shape: (1, 1, 256, 256, 3)
|
45 |
+
observation = {
|
46 |
+
"image_primary": img,
|
47 |
+
"timestep_pad_mask": np.array([[True]])
|
48 |
+
}
|
49 |
+
|
50 |
+
# Create task and predict actions
|
51 |
+
task_obj = model.create_tasks(texts=[request.task])
|
52 |
+
actions = model.sample_actions(
|
53 |
+
observation,
|
54 |
+
task_obj,
|
55 |
+
unnormalization_statistics=model.dataset_statistics["bridge_dataset"]["action"],
|
56 |
+
rng=jax.random.PRNGKey(0)
|
57 |
+
)
|
58 |
+
actions = actions[0]
|
59 |
+
|
60 |
+
# Convert NumPy array to list for JSON response
|
61 |
+
actions_list = actions.tolist()
|
62 |
+
|
63 |
+
return {"actions": actions_list}
|
64 |
+
except Exception as e:
|
65 |
+
raise HTTPException(status_code=500, detail=f"Error processing request: {str(e)}")
|
init_model.py
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from octo.model.octo_model import OctoModel
|
2 |
+
import os
|
3 |
+
|
4 |
+
# Set JAX to CPU (consistent with your setup)
|
5 |
+
os.environ['JAX_PLATFORMS'] = 'cpu'
|
6 |
+
|
7 |
+
# Load the model to cache it
|
8 |
+
model = OctoModel.load_pretrained("hf://rail-berkeley/octo-small-1.5")
|
9 |
+
print("Model downloaded and cached successfully.")
|
main.py
CHANGED
@@ -17,10 +17,17 @@ img = np.array(Image.open(requests.get(IMAGE_URL, stream=True).raw).resize((256,
|
|
17 |
img = img[np.newaxis,np.newaxis,...]
|
18 |
observation = {"image_primary": img, "timestep_pad_mask": np.array([[True]])}
|
19 |
task = model.create_tasks(texts=["pick up the fork"])
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
)
|
26 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
17 |
img = img[np.newaxis,np.newaxis,...]
|
18 |
observation = {"image_primary": img, "timestep_pad_mask": np.array([[True]])}
|
19 |
task = model.create_tasks(texts=["pick up the fork"])
|
20 |
+
norm_actions = model.sample_actions(observation, task, rng=jax.random.PRNGKey(0))
|
21 |
+
norm_actions = norm_actions[0] # remove batch
|
22 |
+
actions = (
|
23 |
+
norm_actions * model.dataset_statistics["bridge_dataset"]['action']['std']
|
24 |
+
+ model.dataset_statistics["bridge_dataset"]['action']['mean']
|
25 |
)
|
26 |
+
actions = np.concatenate(
|
27 |
+
(
|
28 |
+
steps[step+1]['action']['world_vector'],
|
29 |
+
steps[step+1]['action']['rotation_delta'],
|
30 |
+
np.array(steps[step+1]['action']['open_gripper']).astype(np.float32)[None]
|
31 |
+
), axis=-1
|
32 |
+
)
|
33 |
+
print(actions)
|
test_api.py
ADDED
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import requests
|
2 |
+
import base64
|
3 |
+
from PIL import Image
|
4 |
+
import io
|
5 |
+
|
6 |
+
# API endpoint URL (adjust if running on a different host/port)
|
7 |
+
API_URL = "http://localhost:8000/predict"
|
8 |
+
|
9 |
+
# Example image URL from main.py
|
10 |
+
IMAGE_URL = "https://rail.eecs.berkeley.edu/datasets/bridge_release/raw/bridge_data_v2/datacol2_toykitchen7/drawer_pnp/01/2023-04-19_09-18-15/raw/traj_group0/traj0/images0/im_12.jpg"
|
11 |
+
TASK_TEXT = "pick up the fork"
|
12 |
+
|
13 |
+
def test_api(image_url=IMAGE_URL, task=TASK_TEXT):
|
14 |
+
try:
|
15 |
+
# Download image from URL
|
16 |
+
response = requests.get(image_url, stream=True)
|
17 |
+
response.raise_for_status() # Check for HTTP errors
|
18 |
+
img = Image.open(response.raw).resize((256, 256))
|
19 |
+
|
20 |
+
# Convert image to base64
|
21 |
+
img_byte_arr = io.BytesIO()
|
22 |
+
img.save(img_byte_arr, format="JPEG") # Save as JPEG (adjust if needed)
|
23 |
+
img_byte_arr = img_byte_arr.getvalue()
|
24 |
+
base64_string = base64.b64encode(img_byte_arr).decode("utf-8")
|
25 |
+
|
26 |
+
# Prepare payload for API
|
27 |
+
payload = {
|
28 |
+
"image_base64": base64_string,
|
29 |
+
"task": task
|
30 |
+
}
|
31 |
+
|
32 |
+
# Send POST request to API
|
33 |
+
api_response = requests.post(API_URL, json=payload)
|
34 |
+
api_response.raise_for_status() # Check for API errors
|
35 |
+
|
36 |
+
# Print the result
|
37 |
+
result = api_response.json()
|
38 |
+
print(f"Task: {task}")
|
39 |
+
print(f"Image URL: {image_url}")
|
40 |
+
print(f"Predicted Actions: {result['actions']}")
|
41 |
+
|
42 |
+
except requests.exceptions.RequestException as e:
|
43 |
+
print(f"Error fetching image or calling API: {e}")
|
44 |
+
except Exception as e:
|
45 |
+
print(f"Unexpected error: {e}")
|
46 |
+
|
47 |
+
if __name__ == "__main__":
|
48 |
+
# Test with default values (same as main.py)
|
49 |
+
test_api()
|
50 |
+
|
51 |
+
# Test with a different URL and task (optional)
|
52 |
+
# Replace with another valid URL if desired
|
53 |
+
print("\nTesting with another URL and task:")
|
54 |
+
test_api(IMAGE_URL, TASK_TEXT)
|