"""Pydantic data models and other dataclasses. This is the only file that uses Optional[] typing syntax instead of | None syntax to work with pydantic""" from __future__ import annotations import pathlib import secrets import shutil from abc import ABC, abstractmethod from enum import Enum, auto from typing import Any, List, Optional, Union from fastapi import Request from gradio_client.utils import traverse from typing_extensions import Literal from . import wasm_utils if not wasm_utils.IS_WASM: from pydantic import BaseModel, RootModel, ValidationError # type: ignore else: # XXX: Currently Pyodide V2 is not available on Pyodide, # so we install V1 for the Wasm version. from typing import Generic, TypeVar from pydantic import BaseModel as BaseModelV1 from pydantic import ValidationError, schema_of # Map V2 method calls to V1 implementations. # Ref: https://docs.pydantic.dev/latest/migration/#changes-to-pydanticbasemodel class BaseModel(BaseModelV1): pass BaseModel.model_dump = BaseModel.dict # type: ignore BaseModel.model_json_schema = BaseModel.schema # type: ignore # RootModel is not available in V1, so we create a dummy class. PydanticUndefined = object() RootModelRootType = TypeVar("RootModelRootType") class RootModel(BaseModel, Generic[RootModelRootType]): root: RootModelRootType def __init__(self, root: RootModelRootType = PydanticUndefined, **data): if data: if root is not PydanticUndefined: raise ValueError( '"RootModel.__init__" accepts either a single positional argument or arbitrary keyword arguments' ) root = data # type: ignore # XXX: No runtime validation is executed. super().__init__(root=root) # type: ignore def dict(self, **kwargs): return super().dict(**kwargs)["root"] @classmethod def schema(cls, **kwargs): # XXX: kwargs are ignored. return schema_of(cls.__fields__["root"].type_) # type: ignore RootModel.model_dump = RootModel.dict # type: ignore RootModel.model_json_schema = RootModel.schema # type: ignore class PredictBody(BaseModel): class Config: arbitrary_types_allowed = True session_hash: Optional[str] = None event_id: Optional[str] = None data: List[Any] event_data: Optional[Any] = None fn_index: Optional[int] = None trigger_id: Optional[int] = None batched: Optional[ bool ] = False # Whether the data is a batch of samples (i.e. called from the queue if batch=True) or a single sample (i.e. called from the UI) request: Optional[ Request ] = None # dictionary of request headers, query parameters, url, etc. (used to to pass in request for queuing) class ResetBody(BaseModel): event_id: str class ComponentServerBody(BaseModel): session_hash: str component_id: int fn_name: str data: Any class InterfaceTypes(Enum): STANDARD = auto() INPUT_ONLY = auto() OUTPUT_ONLY = auto() UNIFIED = auto() class Estimation(BaseModel): rank: Optional[int] = None queue_size: int rank_eta: Optional[float] = None class ProgressUnit(BaseModel): index: Optional[int] = None length: Optional[int] = None unit: Optional[str] = None progress: Optional[float] = None desc: Optional[str] = None class Progress(BaseModel): progress_data: List[ProgressUnit] = [] class LogMessage(BaseModel): log: str level: Literal["info", "warning"] class GradioBaseModel(ABC): def copy_to_dir(self, dir: str | pathlib.Path) -> GradioDataModel: assert isinstance(self, (BaseModel, RootModel)) if isinstance(dir, str): dir = pathlib.Path(dir) # TODO: Making sure path is unique should be done in caller def unique_copy(obj: dict): data = FileData(**obj) return data._copy_to_dir( str(pathlib.Path(dir / secrets.token_hex(10))) ).model_dump() return self.__class__.from_json( x=traverse( self.model_dump(), unique_copy, FileData.is_file_data, ) ) @classmethod @abstractmethod def from_json(cls, x) -> GradioDataModel: pass class GradioModel(GradioBaseModel, BaseModel): @classmethod def from_json(cls, x) -> GradioModel: return cls(**x) class GradioRootModel(GradioBaseModel, RootModel): @classmethod def from_json(cls, x) -> GradioRootModel: return cls(root=x) GradioDataModel = Union[GradioModel, GradioRootModel] class FileData(GradioModel): path: str # server filepath url: Optional[str] = None # normalised server url size: Optional[int] = None # size in bytes orig_name: Optional[str] = None # original filename mime_type: Optional[str] = None @property def is_none(self): return all( f is None for f in [ self.path, self.url, self.size, self.orig_name, self.mime_type, ] ) @classmethod def from_path(cls, path: str) -> FileData: return cls(path=path) def _copy_to_dir(self, dir: str) -> FileData: pathlib.Path(dir).mkdir(exist_ok=True) new_obj = dict(self) assert self.path new_name = shutil.copy(self.path, dir) new_obj["path"] = new_name return self.__class__(**new_obj) @classmethod def is_file_data(cls, obj: Any): if isinstance(obj, dict): try: return not FileData(**obj).is_none except (TypeError, ValidationError): return False return False