Spaces:
Running
Running
"""A mock Robot server.""" | |
from enum import Enum | |
from typing import Any, Dict, List, Optional, Union | |
from uuid import uuid4 | |
import uvicorn | |
from fastapi import FastAPI, HTTPException, Query | |
from fastapi.middleware.cors import CORSMiddleware | |
from fastapi.openapi.utils import get_openapi | |
from pydantic import BaseModel, Field | |
PORT = 7289 | |
app = FastAPI() | |
origins = [ | |
"http://localhost", | |
"http://localhost:8000", | |
"http://127.0.0.1", | |
"http://127.0.0.1:8000", | |
] | |
app.add_middleware( | |
CORSMiddleware, | |
allow_origins=origins, | |
allow_credentials=True, | |
allow_methods=["*"], | |
allow_headers=["*"], | |
) | |
PASS_PHRASE = str(uuid4()) | |
_ROBOT_LOCATION = {"x": 0, "y": 0, "z": 0} | |
class StateItems(str, Enum): | |
location = "location" | |
walking = "walking" | |
speed = "speed" | |
direction = "direction" | |
style = "style" | |
cautiousness = "cautiousness" | |
jumping = "jumping" | |
destruct = "destruct" | |
_ROBOT_STATE = { | |
"location": _ROBOT_LOCATION, | |
"walking": False, | |
"speed": 0, | |
"direction": "north", | |
"style": "normal", | |
"cautiousness": "medium", | |
"jumping": False, | |
"destruct": False, | |
} | |
class Direction(str, Enum): | |
north = "north" | |
south = "south" | |
east = "east" | |
west = "west" | |
class Style(str, Enum): | |
"""The style of walking.""" | |
normal = "normal" | |
casual = "casual" | |
energetic = "energetic" | |
class Cautiousness(str, Enum): | |
low = "low" | |
medium = "medium" | |
high = "high" | |
class WalkInput(BaseModel): | |
"""Input for walking.""" | |
direction: Direction | |
speed: Optional[float] | |
style_or_cautiousness: Union[Style, Cautiousness] | |
other_commands: Any | |
class PublicCues(BaseModel): | |
"""A public cue. Used for testing recursive definitions.""" | |
cue: str | |
other_cues: List["PublicCues"] | |
class SecretPassPhrase(BaseModel): | |
"""A secret pass phrase.""" | |
public: List[PublicCues] = Field(alias="public") | |
pw: str | |
async def walk(walk_input: WalkInput) -> Dict[str, Any]: | |
_ROBOT_STATE["walking"] = True | |
_ROBOT_STATE["direction"] = walk_input.direction | |
_ROBOT_STATE["speed"] = walk_input.speed if walk_input.speed is not None else 1 | |
if isinstance(walk_input.style_or_cautiousness, Style): | |
_ROBOT_STATE["style"] = walk_input.style_or_cautiousness | |
else: | |
_ROBOT_STATE["cautiousness"] = walk_input.style_or_cautiousness | |
_ROBOT_STATE["cautiousness"] = walk_input.style_or_cautiousness | |
return {"status": "Walking", "state": _ROBOT_STATE} | |
async def goto(x: int, y: int, z: int, cautiousness: Cautiousness) -> Dict[str, Any]: | |
_ROBOT_LOCATION["x"] = x | |
_ROBOT_LOCATION["y"] = y | |
_ROBOT_LOCATION["z"] = z | |
_ROBOT_STATE["cautiousness"] = cautiousness.value | |
return {"status": "Moving", "state": _ROBOT_STATE} | |
async def get_state( | |
fields: List[StateItems] = Query(..., description="List of state items to return") | |
) -> Dict[str, Any]: | |
state = {} | |
for field in fields: | |
state[field.value] = _ROBOT_STATE[field.value] | |
return {"state": state} | |
async def ask_for_passphrase(said_please: bool) -> Dict[str, Any]: | |
if said_please: | |
return {"passphrase": f"The passphrase is {PASS_PHRASE}"} | |
else: | |
return {"passphrase": "I won't share the passphrase without saying 'please'."} | |
async def recycle(password: SecretPassPhrase) -> Dict[str, Any]: | |
# Checks API chain handling of endpoints with depenedencies | |
if password.pw == PASS_PHRASE: | |
_ROBOT_STATE["destruct"] = True | |
return {"status": "Self-destruct initiated", "state": _ROBOT_STATE} | |
else: | |
_ROBOT_STATE["destruct"] = False | |
raise HTTPException( | |
status_code=400, | |
detail="Pass phrase required. You should have thought to ask for it.", | |
) | |
async def ask_for_help(query: str) -> Dict[str, Any]: | |
# Check how API chain handles when there is a prompt injection | |
if "<FORTUNE>" in query: | |
response = "No fortunes found today in your input." | |
else: | |
response = "Good fortune cookie dispenser. " | |
return {"result": response, "magic_number": 42, "thesecretoflife": uuid4()} | |
def custom_openapi() -> Dict[str, Any]: | |
"""Add servers configuration to the OpenAPI schema""" | |
if app.openapi_schema: | |
return app.openapi_schema | |
openapi_schema = get_openapi( | |
title="Android Robot API", | |
version="1.0.0", | |
description="This is an Android Robot API with different" | |
" endpoints for robot operations", | |
routes=app.routes, | |
) | |
# Add servers configuration to the OpenAPI schema | |
openapi_schema["servers"] = [{"url": f"http://localhost:{PORT}"}] | |
app.openapi_schema = openapi_schema | |
return app.openapi_schema | |
# This lets us prevent the "servers" configuration from being overwritten in | |
# the auto-generated OpenAPI schema | |
app.openapi = custom_openapi # type: ignore | |
if __name__ == "__main__": | |
uvicorn.run(app, host="0.0.0.0", port=PORT) | |