Spaces:
Running
Running
File size: 3,354 Bytes
6a0e448 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 |
from __future__ import annotations
import datetime
import json
import os
import uuid
from pathlib import Path
from typing import Any, Optional, Self
from pydantic import BaseModel, Field
from proxy_lite.environments import EnvironmentConfigTypes
from proxy_lite.environments.environment_base import Action, Observation
from proxy_lite.history import MessageHistory
from proxy_lite.solvers import SolverConfigTypes
class Run(BaseModel):
    """A single task run: its interleaved observation/action timeline and outcome.

    Identifiers and timestamps are kept as plain strings so the model can be
    dumped to and reloaded from JSON without custom encoders.
    """

    run_id: str  # str(uuid.uuid4()), assigned by ``initialise``
    task: str
    created_at: str  # str() of a timezone-aware UTC datetime.datetime
    complete: bool = False
    terminated_at: str | None = None  # set by ``terminate``; same format as created_at
    evaluation: dict[str, Any] | None = None
    # Observations and actions in the order they were passed to ``record``.
    history: list[Observation | Action] = Field(default_factory=list)
    # Most recent solver history snapshot passed to ``record`` (replaced, not merged).
    solver_history: MessageHistory | None = None
    result: str | None = None
    env_info: dict[str, Any] = Field(default_factory=dict)
    environment: Optional[EnvironmentConfigTypes] = None
    solver: Optional[SolverConfigTypes] = None

    @classmethod
    def initialise(cls, task: str) -> Self:
        """Create a new run for ``task`` with a fresh UUID and the current UTC time."""
        run_id = str(uuid.uuid4())
        return cls(
            run_id=run_id,
            task=task,
            created_at=str(datetime.datetime.now(datetime.UTC)),
        )

    @classmethod
    def load(cls, run_id: str, local_folder: str | os.PathLike[str] | None = None) -> Self:
        """Load a previously saved run from ``<local_folder>/<run_id>.json``.

        Args:
            run_id: Id of the run; used as the file stem.
            local_folder: Directory holding the trajectory files. ``None`` selects
                the ``local_trajectories`` directory at
                ``Path(__file__).parent.parent.parent``.

        Raises:
            FileNotFoundError: If no file exists for ``run_id``.
            json.JSONDecodeError: If the file is not valid JSON.
            pydantic.ValidationError: If the JSON does not match this model.
        """
        folder = (
            Path(local_folder)
            if local_folder is not None
            else Path(__file__).parent.parent.parent / "local_trajectories"
        )
        with open(folder / f"{run_id}.json", "r", encoding="utf-8") as f:
            # NOTE(review): ``history`` is a union list; pydantic picks the matching
            # member per item, so Observation/Action schemas should be distinguishable.
            return cls.model_validate(json.load(f))

    @property
    def observations(self) -> list[Observation]:
        """All recorded observations, in recording order."""
        return [h for h in self.history if isinstance(h, Observation)]

    @property
    def actions(self) -> list[Action]:
        """All recorded actions, in recording order."""
        return [h for h in self.history if isinstance(h, Action)]

    @property
    def last_action(self) -> Action | None:
        """The most recently recorded action, or ``None`` if there is none."""
        # Scan from the end instead of materialising the full ``actions`` list (twice).
        return next((h for h in reversed(self.history) if isinstance(h, Action)), None)

    @property
    def last_observation(self) -> Observation | None:
        """The most recently recorded observation, or ``None`` if there is none."""
        return next((h for h in reversed(self.history) if isinstance(h, Observation)), None)

    def record(
        self,
        observation: Optional[Observation] = None,
        action: Optional[Action] = None,
        solver_history: Optional[MessageHistory] = None,
    ) -> None:
        """Append an observation or an action to ``history`` and/or update ``solver_history``.

        Args:
            observation: Observation to append.
            action: Action to append.
            solver_history: Replaces the stored ``solver_history`` when given.

        Raises:
            ValueError: If both ``observation`` and ``action`` are given. Only one
                is accepted per call so that the order in ``history`` is unambiguous.
        """
        # Compare against None explicitly: a container-like argument (e.g. an empty
        # history object defining ``__len__``) would otherwise be treated as absent.
        if observation is not None and action is not None:
            raise ValueError("Only one of observation and action can be provided")
        if observation is not None:
            self.history.append(observation)
        if action is not None:
            self.history.append(action)
        if solver_history is not None:
            self.solver_history = solver_history

    def terminate(self) -> None:
        """Stamp ``terminated_at`` with the current UTC time (``complete`` is left as is)."""
        self.terminated_at = str(datetime.datetime.now(datetime.UTC))
class DataRecorder:
    """Create :class:`Run` objects and persist them as JSON files in a local folder."""

    def __init__(self, local_folder: str | os.PathLike[str] | None = None):
        """
        Args:
            local_folder: Directory in which run files are written. ``None``
                selects the ``local_trajectories`` directory at
                ``Path(__file__).parent.parent.parent``, which is also where
                ``Run.load`` reads from by default.
        """
        # Normalise to a Path so ``save`` can join with ``/`` regardless of input type.
        if local_folder is None:
            self.local_folder = Path(__file__).parent.parent.parent / "local_trajectories"
        else:
            self.local_folder = Path(local_folder)

    def initialise_run(self, task: str) -> Run:
        """Ensure the output folder exists and start a fresh :class:`Run` for ``task``."""
        os.makedirs(self.local_folder, exist_ok=True)
        return Run.initialise(task)

    async def terminate(
        self,
        run: Run,
        save: bool = True,
    ) -> None:
        """Mark ``run`` as terminated and, when ``save`` is true, write it to disk."""
        run.terminate()
        if save:
            await self.save(run)

    async def save(self, run: Run) -> None:
        """Write ``run`` to ``<local_folder>/<run_id>.json``, overwriting any existing file.

        ``mode="json"`` makes pydantic convert values that :func:`json.dump`
        cannot encode natively (e.g. datetimes, sets, paths) into JSON-safe
        forms, instead of failing with ``TypeError`` mid-write.
        """
        # The folder may not exist if ``initialise_run`` was never called on this recorder.
        os.makedirs(self.local_folder, exist_ok=True)
        json_payload = run.model_dump(mode="json")
        # NOTE: this is synchronous file I/O and blocks the event loop while it runs.
        with open(self.local_folder / f"{run.run_id}.json", "w", encoding="utf-8") as f:
            json.dump(json_payload, f)
|