Spaces:
Running
Running
from __future__ import annotations | |
import datetime | |
import json | |
import os | |
import uuid | |
from pathlib import Path | |
from typing import Any, Optional, Self | |
from pydantic import BaseModel, Field | |
from proxy_lite.environments import EnvironmentConfigTypes | |
from proxy_lite.environments.environment_base import Action, Observation | |
from proxy_lite.history import MessageHistory | |
from proxy_lite.solvers import SolverConfigTypes | |
class Run(BaseModel):
    """Record of a single agent run: its task, timestamps and interaction trajectory.

    A run is created with :meth:`initialise`, grows via :meth:`record`, is closed
    with :meth:`terminate`, and can be reloaded from its JSON file with :meth:`load`.
    """

    run_id: str  # uuid.UUID, stored as its string form
    task: str
    created_at: str  # datetime.datetime (UTC), stored via str()
    complete: bool = False
    terminated_at: str | None = None  # datetime.datetime (UTC), set by terminate()
    evaluation: dict[str, Any] | None = None
    # Observations and actions interleaved in the order they were recorded.
    history: list[Observation | Action] = Field(default_factory=list)
    solver_history: MessageHistory | None = None
    result: str | None = None
    env_info: dict[str, Any] = Field(default_factory=dict)
    environment: Optional[EnvironmentConfigTypes] = None
    solver: Optional[SolverConfigTypes] = None

    @classmethod
    def initialise(cls, task: str) -> Self:
        """Create a new run for ``task`` with a fresh UUID4 id and a UTC creation timestamp."""
        return cls(
            run_id=str(uuid.uuid4()),
            task=task,
            created_at=str(datetime.datetime.now(datetime.UTC)),
        )

    @classmethod
    def load(cls, run_id: str) -> Self:
        """Load a previously saved run from ``local_trajectories/<run_id>.json``.

        The ``local_trajectories`` directory is resolved three levels above this
        module's file.

        Raises:
            FileNotFoundError: if no trajectory file exists for ``run_id``.
        """
        path = Path(__file__).parent.parent.parent / "local_trajectories" / f"{run_id}.json"
        with open(path, "r", encoding="utf-8") as f:
            return cls(**json.load(f))

    @property
    def observations(self) -> list[Observation]:
        """All observations in ``history``, in recorded order."""
        return [h for h in self.history if isinstance(h, Observation)]

    @property
    def actions(self) -> list[Action]:
        """All actions in ``history``, in recorded order."""
        return [h for h in self.history if isinstance(h, Action)]

    @property
    def last_action(self) -> Action | None:
        """The most recently recorded action, or ``None`` if none was recorded."""
        actions = self.actions  # filter the history once, not twice
        return actions[-1] if actions else None

    @property
    def last_observation(self) -> Observation | None:
        """The most recently recorded observation, or ``None`` if none was recorded."""
        observations = self.observations  # filter the history once, not twice
        return observations[-1] if observations else None

    def record(
        self,
        observation: Optional[Observation] = None,
        action: Optional[Action] = None,
        solver_history: Optional[MessageHistory] = None,
    ) -> None:
        """Append an observation or an action to ``history``; optionally replace ``solver_history``.

        At most one of ``observation`` / ``action`` may be given per call, so the
        order of entries in ``history`` is unambiguous.

        Raises:
            ValueError: if both ``observation`` and ``action`` are provided.
        """
        # Compare against None instead of relying on truthiness: if any of these
        # models defines __len__/__bool__, an "empty" instance would otherwise be
        # silently ignored.
        if observation is not None and action is not None:
            raise ValueError("Only one of observation and action can be provided")
        if observation is not None:
            self.history.append(observation)
        if action is not None:
            self.history.append(action)
        if solver_history is not None:
            self.solver_history = solver_history

    def terminate(self) -> None:
        """Stamp ``terminated_at`` with the current UTC time."""
        self.terminated_at = str(datetime.datetime.now(datetime.UTC))
class DataRecorder:
    """Creates runs and persists them as JSON files in a local folder.

    Args:
        local_folder: Directory that run files are written to. When ``None``
            (the default), this is ``local_trajectories`` three levels above this
            module, the same directory ``Run.load`` reads from.
    """

    def __init__(self, local_folder: str | os.PathLike[str] | None = None):
        # Previously the argument was stored but then unconditionally overwritten
        # in initialise_run; resolve it once here and honour a caller-supplied path.
        if local_folder is None:
            self.local_folder = Path(__file__).parent.parent.parent / "local_trajectories"
        else:
            self.local_folder = Path(local_folder)

    def initialise_run(self, task: str) -> Run:
        """Ensure the output folder exists and return a freshly initialised run for ``task``."""
        os.makedirs(self.local_folder, exist_ok=True)
        return Run.initialise(task)

    async def terminate(
        self,
        run: Run,
        save: bool = True,
    ) -> None:
        """Mark ``run`` as terminated and, if ``save`` is true, write it to disk."""
        run.terminate()
        if save:
            await self.save(run)

    async def save(self, run: Run) -> None:
        """Write ``run`` to ``<local_folder>/<run_id>.json``.

        Note: the file write is synchronous and blocks the event loop while it runs.
        """
        # save() may be called without initialise_run(), so make sure the folder exists.
        os.makedirs(self.local_folder, exist_ok=True)
        # mode="json" converts values such as datetimes, enums or bytes into
        # JSON-native types; a python-mode dump can contain objects json.dump rejects.
        json_payload = run.model_dump(mode="json")
        with open(self.local_folder / f"{run.run_id}.json", "w", encoding="utf-8") as f:
            json.dump(json_payload, f)