"""Contains all of the components that can be used with Gradio Interface / Blocks. Along with the docs for each component, you can find the names of example demos that use each component. These demos are located in the `demo` directory.""" from __future__ import annotations import hashlib import os import secrets import shutil import tempfile import urllib.request from enum import Enum from pathlib import Path from typing import TYPE_CHECKING, Any, Callable import aiofiles import numpy as np import requests from fastapi import UploadFile from gradio_client import utils as client_utils from gradio_client.documentation import set_documentation_group from gradio_client.serializing import ( Serializable, ) from PIL import Image as _Image # using _ to minimize namespace pollution from gradio import processing_utils, utils from gradio.blocks import Block, BlockContext from gradio.deprecation import warn_deprecation, warn_style_method_deprecation from gradio.events import ( EventListener, ) from gradio.layouts import Column, Form, Row if TYPE_CHECKING: from typing import TypedDict class DataframeData(TypedDict): headers: list[str] data: list[list[str | int | bool]] set_documentation_group("component") _Image.init() # fixes https://github.com/gradio-app/gradio/issues/2843 class _Keywords(Enum): NO_VALUE = "NO_VALUE" # Used as a sentinel to determine if nothing is provided as a argument for `value` in `Component.update()` FINISHED_ITERATING = "FINISHED_ITERATING" # Used to skip processing of a component's value (needed for generators + state) class Component(Block, Serializable): """ A base class for defining the methods that all gradio components should have. """ def __init__(self, *args, **kwargs): Block.__init__(self, *args, **kwargs) EventListener.__init__(self) def __str__(self): return self.__repr__() def __repr__(self): return f"{self.get_block_name()}" def get_config(self): """ :return: a dictionary with context variables for the javascript file associated with the context """ return { "name": self.get_block_name(), **super().get_config(), } def preprocess(self, x: Any) -> Any: """ Any preprocessing needed to be performed on function input. """ return x def postprocess(self, y): """ Any postprocessing needed to be performed on function output. """ return y def style(self, *args, **kwargs): """ This method is deprecated. Please set these arguments in the Components constructor instead. """ warn_style_method_deprecation() put_deprecated_params_in_box = False if "rounded" in kwargs: warn_deprecation( "'rounded' styling is no longer supported. To round adjacent components together, place them in a Column(variant='box')." ) if isinstance(kwargs["rounded"], (list, tuple)): put_deprecated_params_in_box = True kwargs.pop("rounded") if "margin" in kwargs: warn_deprecation( "'margin' styling is no longer supported. To place adjacent components together without margin, place them in a Column(variant='box')." ) if isinstance(kwargs["margin"], (list, tuple)): put_deprecated_params_in_box = True kwargs.pop("margin") if "border" in kwargs: warn_deprecation( "'border' styling is no longer supported. To place adjacent components in a shared border, place them in a Column(variant='box')." ) kwargs.pop("border") for key in kwargs: warn_deprecation(f"Unknown style parameter: {key}") if ( put_deprecated_params_in_box and isinstance(self.parent, (Row, Column)) and self.parent.variant == "default" ): self.parent.variant = "compact" return self class IOComponent(Component): """ A base class for defining methods that all input/output components should have. """ def __init__( self, *, value: Any = None, label: str | None = None, info: str | None = None, show_label: bool | None = None, container: bool = True, scale: int | None = None, min_width: int | None = None, interactive: bool | None = None, visible: bool = True, elem_id: str | None = None, elem_classes: list[str] | str | None = None, load_fn: Callable | None = None, every: float | None = None, **kwargs, ): self.temp_files: set[str] = set() self.DEFAULT_TEMP_DIR = os.environ.get("GRADIO_TEMP_DIR") or str( Path(tempfile.gettempdir()) / "gradio" ) Component.__init__( self, elem_id=elem_id, elem_classes=elem_classes, visible=visible, **kwargs ) self.label = label self.info = info if not container: if show_label: warn_deprecation("show_label has no effect when container is False.") show_label = False if show_label is None: show_label = True self.show_label = show_label self.container = container if scale is not None and scale != round(scale): warn_deprecation( f"'scale' value should be an integer. Using {scale} will cause issues." ) self.scale = scale self.min_width = min_width self.interactive = interactive # load_event is set in the Blocks.attach_load_events method self.load_event: None | dict[str, Any] = None self.load_event_to_attach = None load_fn, initial_value = self.get_load_fn_and_initial_value(value) self.value = ( initial_value if self._skip_init_processing else self.postprocess(initial_value) ) if callable(load_fn): self.attach_load_event(load_fn, every) @staticmethod def hash_file(file_path: str | Path, chunk_num_blocks: int = 128) -> str: sha1 = hashlib.sha1() with open(file_path, "rb") as f: for chunk in iter(lambda: f.read(chunk_num_blocks * sha1.block_size), b""): sha1.update(chunk) return sha1.hexdigest() @staticmethod def hash_url(url: str, chunk_num_blocks: int = 128) -> str: sha1 = hashlib.sha1() remote = urllib.request.urlopen(url) max_file_size = 100 * 1024 * 1024 # 100MB total_read = 0 while True: data = remote.read(chunk_num_blocks * sha1.block_size) total_read += chunk_num_blocks * sha1.block_size if not data or total_read > max_file_size: break sha1.update(data) return sha1.hexdigest() @staticmethod def hash_bytes(bytes: bytes): sha1 = hashlib.sha1() sha1.update(bytes) return sha1.hexdigest() @staticmethod def hash_base64(base64_encoding: str, chunk_num_blocks: int = 128) -> str: sha1 = hashlib.sha1() for i in range(0, len(base64_encoding), chunk_num_blocks * sha1.block_size): data = base64_encoding[i : i + chunk_num_blocks * sha1.block_size] sha1.update(data.encode("utf-8")) return sha1.hexdigest() def make_temp_copy_if_needed(self, file_path: str | Path) -> str: """Returns a temporary file path for a copy of the given file path if it does not already exist. Otherwise returns the path to the existing temp file.""" temp_dir = self.hash_file(file_path) temp_dir = Path(self.DEFAULT_TEMP_DIR) / temp_dir temp_dir.mkdir(exist_ok=True, parents=True) name = client_utils.strip_invalid_filename_characters(Path(file_path).name) full_temp_file_path = str(utils.abspath(temp_dir / name)) if not Path(full_temp_file_path).exists(): shutil.copy2(file_path, full_temp_file_path) self.temp_files.add(full_temp_file_path) return full_temp_file_path async def save_uploaded_file(self, file: UploadFile, upload_dir: str) -> str: temp_dir = secrets.token_hex( 20 ) # Since the full file is being uploaded anyways, there is no benefit to hashing the file. temp_dir = Path(upload_dir) / temp_dir temp_dir.mkdir(exist_ok=True, parents=True) if file.filename: file_name = Path(file.filename).name name = client_utils.strip_invalid_filename_characters(file_name) else: name = f"tmp{secrets.token_hex(5)}" full_temp_file_path = str(utils.abspath(temp_dir / name)) async with aiofiles.open(full_temp_file_path, "wb") as output_file: while True: content = await file.read(100 * 1024 * 1024) if not content: break await output_file.write(content) return full_temp_file_path def download_temp_copy_if_needed(self, url: str) -> str: """Downloads a file and makes a temporary file path for a copy if does not already exist. Otherwise returns the path to the existing temp file.""" temp_dir = self.hash_url(url) temp_dir = Path(self.DEFAULT_TEMP_DIR) / temp_dir temp_dir.mkdir(exist_ok=True, parents=True) name = client_utils.strip_invalid_filename_characters(Path(url).name) full_temp_file_path = str(utils.abspath(temp_dir / name)) if not Path(full_temp_file_path).exists(): with requests.get(url, stream=True) as r, open( full_temp_file_path, "wb" ) as f: shutil.copyfileobj(r.raw, f) self.temp_files.add(full_temp_file_path) return full_temp_file_path def base64_to_temp_file_if_needed( self, base64_encoding: str, file_name: str | None = None ) -> str: """Converts a base64 encoding to a file and returns the path to the file if the file doesn't already exist. Otherwise returns the path to the existing file. """ temp_dir = self.hash_base64(base64_encoding) temp_dir = Path(self.DEFAULT_TEMP_DIR) / temp_dir temp_dir.mkdir(exist_ok=True, parents=True) guess_extension = client_utils.get_extension(base64_encoding) if file_name: file_name = client_utils.strip_invalid_filename_characters(file_name) elif guess_extension: file_name = f"file.{guess_extension}" else: file_name = "file" full_temp_file_path = str(utils.abspath(temp_dir / file_name)) # type: ignore if not Path(full_temp_file_path).exists(): data, _ = client_utils.decode_base64_to_binary(base64_encoding) with open(full_temp_file_path, "wb") as fb: fb.write(data) self.temp_files.add(full_temp_file_path) return full_temp_file_path def pil_to_temp_file(self, img: _Image.Image, dir: str, format="png") -> str: bytes_data = processing_utils.encode_pil_to_bytes(img, format) temp_dir = Path(dir) / self.hash_bytes(bytes_data) temp_dir.mkdir(exist_ok=True, parents=True) filename = str(temp_dir / f"image.{format}") img.save(filename, pnginfo=processing_utils.get_pil_metadata(img)) return filename def img_array_to_temp_file(self, arr: np.ndarray, dir: str) -> str: pil_image = _Image.fromarray( processing_utils._convert(arr, np.uint8, force_copy=False) ) return self.pil_to_temp_file(pil_image, dir, format="png") def audio_to_temp_file(self, data: np.ndarray, sample_rate: int, format: str): temp_dir = Path(self.DEFAULT_TEMP_DIR) / self.hash_bytes(data.tobytes()) temp_dir.mkdir(exist_ok=True, parents=True) filename = str(temp_dir / f"audio.{format}") processing_utils.audio_to_file(sample_rate, data, filename, format=format) return filename def file_bytes_to_file(self, data: bytes, file_name: str): path = Path(self.DEFAULT_TEMP_DIR) / self.hash_bytes(data) path.mkdir(exist_ok=True, parents=True) path = path / Path(file_name).name path.write_bytes(data) return path def get_config(self): config = { "label": self.label, "show_label": self.show_label, "container": self.container, "scale": self.scale, "min_width": self.min_width, "interactive": self.interactive, **super().get_config(), } if self.info: config["info"] = self.info return config @staticmethod def get_load_fn_and_initial_value(value): if callable(value): initial_value = value() load_fn = value else: initial_value = value load_fn = None return load_fn, initial_value def attach_load_event(self, callable: Callable, every: float | None): """Add a load event that runs `callable`, optionally every `every` seconds.""" self.load_event_to_attach = (callable, every) def as_example(self, input_data): """Return the input data in a way that can be displayed by the examples dataset component in the front-end.""" return input_data class FormComponent: def get_expected_parent(self) -> type[Form] | None: if getattr(self, "container", None) is False: return None return Form def component(cls_name: str) -> Component: obj = utils.component_or_layout_class(cls_name)() if isinstance(obj, BlockContext): raise ValueError(f"Invalid component: {obj.__class__}") return obj def get_component_instance( comp: str | dict | Component, render: bool | None = None ) -> Component: """ Returns a component instance from a string, dict, or Component object. Parameters: comp: the component to instantiate. If a string, must be the name of a component, e.g. "dropdown". If a dict, must have a "name" key, e.g. {"name": "dropdown", "choices": ["a", "b"]}. If a Component object, will be returned as is. render: whether to render the component. If True, renders the component (if not already rendered). If False, *unrenders* the component (if already rendered) -- this is useful when constructing an Interface or ChatInterface inside of a Blocks. If None, does not render or unrender the component. """ if isinstance(comp, str): component_obj = component(comp) elif isinstance(comp, dict): name = comp.pop("name") component_cls = utils.component_or_layout_class(name) component_obj = component_cls(**comp) if isinstance(component_obj, BlockContext): raise ValueError(f"Invalid component: {name}") elif isinstance(comp, Component): component_obj = comp else: raise ValueError( f"Component must provided as a `str` or `dict` or `Component` but is {comp}" ) if render and not component_obj.is_rendered: component_obj.render() elif render is False and component_obj.is_rendered: component_obj.unrender() return component_obj