"""Contains all of the components that can be used with Gradio Interface / Blocks. Along with the docs for each component, you can find the names of example demos that use each component. These demos are located in the `demo` directory.""" from __future__ import annotations import abc import hashlib import json import sys import warnings from abc import ABC, abstractmethod from enum import Enum from pathlib import Path from typing import TYPE_CHECKING, Any, Callable from gradio_client.documentation import set_documentation_group from PIL import Image as _Image # using _ to minimize namespace pollution from gradio import utils from gradio.blocks import Block, BlockContext from gradio.component_meta import ComponentMeta from gradio.data_classes import GradioDataModel from gradio.events import EventListener from gradio.layouts import Form from gradio.processing_utils import move_files_to_cache if TYPE_CHECKING: from typing import TypedDict class DataframeData(TypedDict): headers: list[str] data: list[list[str | int | bool]] set_documentation_group("component") _Image.init() # fixes https://github.com/gradio-app/gradio/issues/2843 class _Keywords(Enum): NO_VALUE = "NO_VALUE" # Used as a sentinel to determine if nothing is provided as a argument for `value` in `Component.update()` FINISHED_ITERATING = "FINISHED_ITERATING" # Used to skip processing of a component's value (needed for generators + state) from gradio.events import Dependency class ComponentBase(ABC, metaclass=ComponentMeta): EVENTS: list[EventListener | str] = [] @abstractmethod def preprocess(self, payload: Any) -> Any: """ Any preprocessing needed to be performed on function input. """ return payload @abstractmethod def postprocess(self, value): """ Any postprocessing needed to be performed on function output. """ return value @abstractmethod def as_example(self, value): """ Return the input data in a way that can be displayed by the examples dataset component in the front-end. For example, only return the name of a file as opposed to a full path. Or get the head of a dataframe. Must be able to be converted to a string to put in the config. """ pass @abstractmethod def api_info(self) -> dict[str, list[str]]: """ The typing information for this component as a dictionary whose values are a list of 2 strings: [Python type, language-agnostic description]. Keys of the dictionary are: raw_input, raw_output, serialized_input, serialized_output """ pass @abstractmethod def example_inputs(self) -> Any: """ The example inputs for this component as a dictionary whose values are example inputs compatible with this component. Keys of the dictionary are: raw, serialized """ pass @abstractmethod def flag(self, payload: Any | GradioDataModel, flag_dir: str | Path = "") -> str: """ Write the component's value to a format that can be stored in a csv or jsonl format for flagging. """ pass @abstractmethod def read_from_flag( self, payload: Any, flag_dir: str | Path | None = None, ) -> GradioDataModel | Any: """ Convert the data from the csv or jsonl file into the component state. """ return payload @property @abstractmethod def skip_api(self): """Whether this component should be skipped from the api return value""" @classmethod def has_event(cls, event: str | EventListener) -> bool: return event in cls.EVENTS @classmethod def get_component_class_id(cls) -> str: module_name = cls.__module__ module_path = sys.modules[module_name].__file__ module_hash = hashlib.md5(f"{cls.__name__}_{module_path}".encode()).hexdigest() return module_hash def server(fn): fn._is_server_fn = True return fn class Component(ComponentBase, Block): """ A base class for defining methods that all input/output components should have. """ def __init__( self, value: Any = None, *, label: str | None = None, info: str | None = None, show_label: bool | None = None, container: bool = True, scale: int | None = None, min_width: int | None = None, interactive: bool | None = None, visible: bool = True, elem_id: str | None = None, elem_classes: list[str] | str | None = None, render: bool = True, load_fn: Callable | None = None, every: float | None = None, ): self.server_fns = [ value for value in self.__class__.__dict__.values() if callable(value) and getattr(value, "_is_server_fn", False) ] # Svelte components expect elem_classes to be a list # If we don't do this, returning a new component for an # update will break the frontend if not elem_classes: elem_classes = [] # This gets overridden when `select` is called self._selectable = False if not hasattr(self, "data_model"): self.data_model: type[GradioDataModel] | None = None Block.__init__( self, elem_id=elem_id, elem_classes=elem_classes, visible=visible, render=render, ) if isinstance(self, StreamingInput): self.check_streamable() self.label = label self.info = info if not container: if show_label: warnings.warn("show_label has no effect when container is False.") show_label = False if show_label is None: show_label = True self.show_label = show_label self.container = container if scale is not None and scale != round(scale): warnings.warn( f"'scale' value should be an integer. Using {scale} will cause issues." ) self.scale = scale self.min_width = min_width self.interactive = interactive # load_event is set in the Blocks.attach_load_events method self.load_event: None | dict[str, Any] = None self.load_event_to_attach: None | tuple[Callable, float | None] = None load_fn, initial_value = self.get_load_fn_and_initial_value(value) initial_value = self.postprocess(initial_value) self.value = move_files_to_cache(initial_value, self, postprocess=True) # type: ignore if callable(load_fn): self.attach_load_event(load_fn, every) self.component_class_id = self.__class__.get_component_class_id() TEMPLATE_DIR = "./templates/" FRONTEND_DIR = "../../frontend/" def get_config(self): config = super().get_config() if self.info: config["info"] = self.info if len(self.server_fns): config["server_fns"] = [fn.__name__ for fn in self.server_fns] config.pop("render", None) return config @property def skip_api(self): return False @staticmethod def get_load_fn_and_initial_value(value): if callable(value): initial_value = value() load_fn = value else: initial_value = value load_fn = None return load_fn, initial_value def __str__(self): return self.__repr__() def __repr__(self): return f"{self.get_block_name()}" def attach_load_event(self, callable: Callable, every: float | None): """Add a load event that runs `callable`, optionally every `every` seconds.""" self.load_event_to_attach = (callable, every) def as_example(self, input_data): """Return the input data in a way that can be displayed by the examples dataset component in the front-end.""" return input_data def api_info(self) -> dict[str, Any]: """ The typing information for this component as a dictionary whose values are a list of 2 strings: [Python type, language-agnostic description]. Keys of the dictionary are: raw_input, raw_output, serialized_input, serialized_output """ if self.data_model is not None: return self.data_model.model_json_schema() raise NotImplementedError( f"The api_info method has not been implemented for {self.get_block_name()}" ) def flag(self, payload: Any, flag_dir: str | Path = "") -> str: """ Write the component's value to a format that can be stored in a csv or jsonl format for flagging. """ if self.data_model: payload = self.data_model.from_json(payload) Path(flag_dir).mkdir(exist_ok=True) return payload.copy_to_dir(flag_dir).model_dump_json() return payload def read_from_flag( self, payload: Any, flag_dir: str | Path | None = None, ): """ Convert the data from the csv or jsonl file into the component state. """ if self.data_model: return self.data_model.from_json(json.loads(payload)) return payload class FormComponent(Component): def get_expected_parent(self) -> type[Form] | None: if getattr(self, "container", None) is False: return None return Form def preprocess(self, payload: Any) -> Any: return payload def postprocess(self, value): return value class StreamingOutput(metaclass=abc.ABCMeta): def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self.streaming: bool @abc.abstractmethod def stream_output( self, value, output_id: str, first_chunk: bool ) -> tuple[bytes, Any]: pass class StreamingInput(metaclass=abc.ABCMeta): def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) @abc.abstractmethod def check_streamable(self): """Used to check if streaming is supported given the input.""" pass def component(cls_name: str, render: bool) -> Component: obj = utils.component_or_layout_class(cls_name)(render=render) if isinstance(obj, BlockContext): raise ValueError(f"Invalid component: {obj.__class__}") assert isinstance(obj, Component) return obj def get_component_instance( comp: str | dict | Component, render: bool = False, unrender: bool = False ) -> Component: """ Returns a component instance from a string, dict, or Component object. Parameters: comp: the component to instantiate. If a string, must be the name of a component, e.g. "dropdown". If a dict, must have a "name" key, e.g. {"name": "dropdown", "choices": ["a", "b"]}. If a Component object, will be returned as is. render: whether to render the component. If True, renders the component (if not already rendered). If False, does not do anything. unrender: whether to unrender the component. If True, unrenders the the component (if already rendered) -- this is useful when constructing an Interface or ChatInterface inside of a Blocks. If False, does not do anything. """ if isinstance(comp, str): component_obj = component(comp, render=render) elif isinstance(comp, dict): name = comp.pop("name") component_cls = utils.component_or_layout_class(name) component_obj = component_cls(**comp, render=render) if isinstance(component_obj, BlockContext): raise ValueError(f"Invalid component: {name}") elif isinstance(comp, Component): component_obj = comp else: raise ValueError( f"Component must provided as a `str` or `dict` or `Component` but is {comp}" ) if render and not component_obj.is_rendered: component_obj.render() elif unrender and component_obj.is_rendered: component_obj.unrender() assert isinstance(component_obj, Component) return component_obj