OpkaGames's picture
Upload folder using huggingface_hub
870ab6b
"""Contains all of the components that can be used with Gradio Interface / Blocks.
Along with the docs for each component, you can find the names of example demos that use
each component. These demos are located in the `demo` directory."""
from __future__ import annotations
import hashlib
import os
import secrets
import shutil
import tempfile
import urllib.request
from enum import Enum
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable
import aiofiles
import numpy as np
import requests
from fastapi import UploadFile
from gradio_client import utils as client_utils
from gradio_client.documentation import set_documentation_group
from gradio_client.serializing import (
Serializable,
)
from PIL import Image as _Image # using _ to minimize namespace pollution
from gradio import processing_utils, utils
from gradio.blocks import Block, BlockContext
from gradio.deprecation import warn_deprecation, warn_style_method_deprecation
from gradio.events import (
EventListener,
)
from gradio.layouts import Column, Form, Row
if TYPE_CHECKING:
from typing import TypedDict
class DataframeData(TypedDict):
headers: list[str]
data: list[list[str | int | bool]]
set_documentation_group("component")
_Image.init() # fixes https://github.com/gradio-app/gradio/issues/2843
class _Keywords(Enum):
NO_VALUE = "NO_VALUE" # Used as a sentinel to determine if nothing is provided as a argument for `value` in `Component.update()`
FINISHED_ITERATING = "FINISHED_ITERATING" # Used to skip processing of a component's value (needed for generators + state)
class Component(Block, Serializable):
"""
A base class for defining the methods that all gradio components should have.
"""
def __init__(self, *args, **kwargs):
Block.__init__(self, *args, **kwargs)
EventListener.__init__(self)
def __str__(self):
return self.__repr__()
def __repr__(self):
return f"{self.get_block_name()}"
def get_config(self):
"""
:return: a dictionary with context variables for the javascript file associated with the context
"""
return {
"name": self.get_block_name(),
**super().get_config(),
}
def preprocess(self, x: Any) -> Any:
"""
Any preprocessing needed to be performed on function input.
"""
return x
def postprocess(self, y):
"""
Any postprocessing needed to be performed on function output.
"""
return y
def style(self, *args, **kwargs):
"""
This method is deprecated. Please set these arguments in the Components constructor instead.
"""
warn_style_method_deprecation()
put_deprecated_params_in_box = False
if "rounded" in kwargs:
warn_deprecation(
"'rounded' styling is no longer supported. To round adjacent components together, place them in a Column(variant='box')."
)
if isinstance(kwargs["rounded"], (list, tuple)):
put_deprecated_params_in_box = True
kwargs.pop("rounded")
if "margin" in kwargs:
warn_deprecation(
"'margin' styling is no longer supported. To place adjacent components together without margin, place them in a Column(variant='box')."
)
if isinstance(kwargs["margin"], (list, tuple)):
put_deprecated_params_in_box = True
kwargs.pop("margin")
if "border" in kwargs:
warn_deprecation(
"'border' styling is no longer supported. To place adjacent components in a shared border, place them in a Column(variant='box')."
)
kwargs.pop("border")
for key in kwargs:
warn_deprecation(f"Unknown style parameter: {key}")
if (
put_deprecated_params_in_box
and isinstance(self.parent, (Row, Column))
and self.parent.variant == "default"
):
self.parent.variant = "compact"
return self
class IOComponent(Component):
"""
A base class for defining methods that all input/output components should have.
"""
def __init__(
self,
*,
value: Any = None,
label: str | None = None,
info: str | None = None,
show_label: bool | None = None,
container: bool = True,
scale: int | None = None,
min_width: int | None = None,
interactive: bool | None = None,
visible: bool = True,
elem_id: str | None = None,
elem_classes: list[str] | str | None = None,
load_fn: Callable | None = None,
every: float | None = None,
**kwargs,
):
self.temp_files: set[str] = set()
self.DEFAULT_TEMP_DIR = os.environ.get("GRADIO_TEMP_DIR") or str(
Path(tempfile.gettempdir()) / "gradio"
)
Component.__init__(
self, elem_id=elem_id, elem_classes=elem_classes, visible=visible, **kwargs
)
self.label = label
self.info = info
if not container:
if show_label:
warn_deprecation("show_label has no effect when container is False.")
show_label = False
if show_label is None:
show_label = True
self.show_label = show_label
self.container = container
if scale is not None and scale != round(scale):
warn_deprecation(
f"'scale' value should be an integer. Using {scale} will cause issues."
)
self.scale = scale
self.min_width = min_width
self.interactive = interactive
# load_event is set in the Blocks.attach_load_events method
self.load_event: None | dict[str, Any] = None
self.load_event_to_attach = None
load_fn, initial_value = self.get_load_fn_and_initial_value(value)
self.value = (
initial_value
if self._skip_init_processing
else self.postprocess(initial_value)
)
if callable(load_fn):
self.attach_load_event(load_fn, every)
@staticmethod
def hash_file(file_path: str | Path, chunk_num_blocks: int = 128) -> str:
sha1 = hashlib.sha1()
with open(file_path, "rb") as f:
for chunk in iter(lambda: f.read(chunk_num_blocks * sha1.block_size), b""):
sha1.update(chunk)
return sha1.hexdigest()
@staticmethod
def hash_url(url: str, chunk_num_blocks: int = 128) -> str:
sha1 = hashlib.sha1()
remote = urllib.request.urlopen(url)
max_file_size = 100 * 1024 * 1024 # 100MB
total_read = 0
while True:
data = remote.read(chunk_num_blocks * sha1.block_size)
total_read += chunk_num_blocks * sha1.block_size
if not data or total_read > max_file_size:
break
sha1.update(data)
return sha1.hexdigest()
@staticmethod
def hash_bytes(bytes: bytes):
sha1 = hashlib.sha1()
sha1.update(bytes)
return sha1.hexdigest()
@staticmethod
def hash_base64(base64_encoding: str, chunk_num_blocks: int = 128) -> str:
sha1 = hashlib.sha1()
for i in range(0, len(base64_encoding), chunk_num_blocks * sha1.block_size):
data = base64_encoding[i : i + chunk_num_blocks * sha1.block_size]
sha1.update(data.encode("utf-8"))
return sha1.hexdigest()
def make_temp_copy_if_needed(self, file_path: str | Path) -> str:
"""Returns a temporary file path for a copy of the given file path if it does
not already exist. Otherwise returns the path to the existing temp file."""
temp_dir = self.hash_file(file_path)
temp_dir = Path(self.DEFAULT_TEMP_DIR) / temp_dir
temp_dir.mkdir(exist_ok=True, parents=True)
name = client_utils.strip_invalid_filename_characters(Path(file_path).name)
full_temp_file_path = str(utils.abspath(temp_dir / name))
if not Path(full_temp_file_path).exists():
shutil.copy2(file_path, full_temp_file_path)
self.temp_files.add(full_temp_file_path)
return full_temp_file_path
async def save_uploaded_file(self, file: UploadFile, upload_dir: str) -> str:
temp_dir = secrets.token_hex(
20
) # Since the full file is being uploaded anyways, there is no benefit to hashing the file.
temp_dir = Path(upload_dir) / temp_dir
temp_dir.mkdir(exist_ok=True, parents=True)
if file.filename:
file_name = Path(file.filename).name
name = client_utils.strip_invalid_filename_characters(file_name)
else:
name = f"tmp{secrets.token_hex(5)}"
full_temp_file_path = str(utils.abspath(temp_dir / name))
async with aiofiles.open(full_temp_file_path, "wb") as output_file:
while True:
content = await file.read(100 * 1024 * 1024)
if not content:
break
await output_file.write(content)
return full_temp_file_path
def download_temp_copy_if_needed(self, url: str) -> str:
"""Downloads a file and makes a temporary file path for a copy if does not already
exist. Otherwise returns the path to the existing temp file."""
temp_dir = self.hash_url(url)
temp_dir = Path(self.DEFAULT_TEMP_DIR) / temp_dir
temp_dir.mkdir(exist_ok=True, parents=True)
name = client_utils.strip_invalid_filename_characters(Path(url).name)
full_temp_file_path = str(utils.abspath(temp_dir / name))
if not Path(full_temp_file_path).exists():
with requests.get(url, stream=True) as r, open(
full_temp_file_path, "wb"
) as f:
shutil.copyfileobj(r.raw, f)
self.temp_files.add(full_temp_file_path)
return full_temp_file_path
def base64_to_temp_file_if_needed(
self, base64_encoding: str, file_name: str | None = None
) -> str:
"""Converts a base64 encoding to a file and returns the path to the file if
the file doesn't already exist. Otherwise returns the path to the existing file.
"""
temp_dir = self.hash_base64(base64_encoding)
temp_dir = Path(self.DEFAULT_TEMP_DIR) / temp_dir
temp_dir.mkdir(exist_ok=True, parents=True)
guess_extension = client_utils.get_extension(base64_encoding)
if file_name:
file_name = client_utils.strip_invalid_filename_characters(file_name)
elif guess_extension:
file_name = f"file.{guess_extension}"
else:
file_name = "file"
full_temp_file_path = str(utils.abspath(temp_dir / file_name)) # type: ignore
if not Path(full_temp_file_path).exists():
data, _ = client_utils.decode_base64_to_binary(base64_encoding)
with open(full_temp_file_path, "wb") as fb:
fb.write(data)
self.temp_files.add(full_temp_file_path)
return full_temp_file_path
def pil_to_temp_file(self, img: _Image.Image, dir: str, format="png") -> str:
bytes_data = processing_utils.encode_pil_to_bytes(img, format)
temp_dir = Path(dir) / self.hash_bytes(bytes_data)
temp_dir.mkdir(exist_ok=True, parents=True)
filename = str(temp_dir / f"image.{format}")
img.save(filename, pnginfo=processing_utils.get_pil_metadata(img))
return filename
def img_array_to_temp_file(self, arr: np.ndarray, dir: str) -> str:
pil_image = _Image.fromarray(
processing_utils._convert(arr, np.uint8, force_copy=False)
)
return self.pil_to_temp_file(pil_image, dir, format="png")
def audio_to_temp_file(self, data: np.ndarray, sample_rate: int, format: str):
temp_dir = Path(self.DEFAULT_TEMP_DIR) / self.hash_bytes(data.tobytes())
temp_dir.mkdir(exist_ok=True, parents=True)
filename = str(temp_dir / f"audio.{format}")
processing_utils.audio_to_file(sample_rate, data, filename, format=format)
return filename
def file_bytes_to_file(self, data: bytes, file_name: str):
path = Path(self.DEFAULT_TEMP_DIR) / self.hash_bytes(data)
path.mkdir(exist_ok=True, parents=True)
path = path / Path(file_name).name
path.write_bytes(data)
return path
def get_config(self):
config = {
"label": self.label,
"show_label": self.show_label,
"container": self.container,
"scale": self.scale,
"min_width": self.min_width,
"interactive": self.interactive,
**super().get_config(),
}
if self.info:
config["info"] = self.info
return config
@staticmethod
def get_load_fn_and_initial_value(value):
if callable(value):
initial_value = value()
load_fn = value
else:
initial_value = value
load_fn = None
return load_fn, initial_value
def attach_load_event(self, callable: Callable, every: float | None):
"""Add a load event that runs `callable`, optionally every `every` seconds."""
self.load_event_to_attach = (callable, every)
def as_example(self, input_data):
"""Return the input data in a way that can be displayed by the examples dataset component in the front-end."""
return input_data
class FormComponent:
def get_expected_parent(self) -> type[Form] | None:
if getattr(self, "container", None) is False:
return None
return Form
def component(cls_name: str) -> Component:
obj = utils.component_or_layout_class(cls_name)()
if isinstance(obj, BlockContext):
raise ValueError(f"Invalid component: {obj.__class__}")
return obj
def get_component_instance(
comp: str | dict | Component, render: bool | None = None
) -> Component:
"""
Returns a component instance from a string, dict, or Component object.
Parameters:
comp: the component to instantiate. If a string, must be the name of a component, e.g. "dropdown". If a dict, must have a "name" key, e.g. {"name": "dropdown", "choices": ["a", "b"]}. If a Component object, will be returned as is.
render: whether to render the component. If True, renders the component (if not already rendered). If False, *unrenders* the component (if already rendered) -- this is useful when constructing an Interface or ChatInterface inside of a Blocks. If None, does not render or unrender the component.
"""
if isinstance(comp, str):
component_obj = component(comp)
elif isinstance(comp, dict):
name = comp.pop("name")
component_cls = utils.component_or_layout_class(name)
component_obj = component_cls(**comp)
if isinstance(component_obj, BlockContext):
raise ValueError(f"Invalid component: {name}")
elif isinstance(comp, Component):
component_obj = comp
else:
raise ValueError(
f"Component must provided as a `str` or `dict` or `Component` but is {comp}"
)
if render and not component_obj.is_rendered:
component_obj.render()
elif render is False and component_obj.is_rendered:
component_obj.unrender()
return component_obj