|
""" Handy utility functions. """ |
|
|
|
from __future__ import annotations |
|
|
|
import asyncio |
|
import copy |
|
import functools |
|
import importlib |
|
import inspect |
|
import json |
|
import json.decoder |
|
import os |
|
import pkgutil |
|
import pprint |
|
import random |
|
import re |
|
import threading |
|
import time |
|
import traceback |
|
import typing |
|
import warnings |
|
from abc import ABC, abstractmethod |
|
from contextlib import contextmanager |
|
from io import BytesIO |
|
from numbers import Number |
|
from pathlib import Path |
|
from types import GeneratorType |
|
from typing import ( |
|
TYPE_CHECKING, |
|
Any, |
|
Callable, |
|
Iterator, |
|
Optional, |
|
TypeVar, |
|
) |
|
|
|
import anyio |
|
import matplotlib |
|
import requests |
|
from gradio_client.serializing import Serializable |
|
from typing_extensions import ParamSpec |
|
|
|
import gradio |
|
from gradio.context import Context |
|
from gradio.strings import en |
|
|
|
if TYPE_CHECKING: |
|
from gradio.blocks import Block, BlockContext, Blocks |
|
from gradio.components import Component |
|
from gradio.routes import App |
|
|
|
# Path of the JSON file that launch_counter() reads and writes; it is located in
# the directory of the installed gradio package.
JSON_PATH = os.path.join(os.path.dirname(gradio.__file__), "launches.json")

# Type variables for decorators that keep the wrapped callable's signature
# (used by concurrency_count_warning).
P = ParamSpec("P")
T = TypeVar("T")
|
|
|
|
|
def get_package_version() -> str:
    """Return the "version" field of the package.json shipped next to this module.

    If the field is absent, "" is returned. Any exception while loading,
    decoding or parsing the file also yields "" instead of propagating.
    """
    try:
        raw_bytes = pkgutil.get_data(__name__, "package.json")
        parsed = json.loads(raw_bytes.decode("utf-8").strip())
        return parsed.get("version", "")
    except Exception:
        return ""
|
|
|
|
|
def safe_get_lock() -> asyncio.Lock | None:
    """Get an ``asyncio.Lock()`` without fear of getting an Exception.

    Needed because in reload mode we import the Blocks object outside
    the main thread, where ``asyncio.get_event_loop()`` may raise
    ``RuntimeError``.

    Returns:
        A new ``asyncio.Lock``, or ``None`` when ``asyncio.get_event_loop()``
        raises ``RuntimeError`` in the current thread. Callers must handle
        the ``None`` case.
    """
    try:
        asyncio.get_event_loop()
        return asyncio.Lock()
    except RuntimeError:
        return None
|
|
|
|
|
class BaseReloader(ABC):
    """Shared logic for replacing the Blocks served by a running app.

    Subclasses supply ``running_app``. This class decides whether a new demo
    is compatible with the running one and performs the swap.
    """

    @property
    @abstractmethod
    def running_app(self) -> App:
        pass

    def queue_changed(self, demo: Blocks):
        """Return True if exactly one of the running Blocks and ``demo`` has a ``_queue``."""
        running_has_queue = hasattr(self.running_app.blocks, "_queue")
        new_has_queue = hasattr(demo, "_queue")
        return running_has_queue != new_has_queue

    def swap_blocks(self, demo: Blocks):
        """Make ``demo`` the running app's Blocks.

        If the current Blocks owns a ``_queue``, the same queue object is
        re-pointed at ``demo.dependencies`` and attached to ``demo`` first.
        """
        current = self.running_app.blocks
        assert current

        if hasattr(current, "_queue"):
            shared_queue = current._queue
            shared_queue.blocks_dependencies = demo.dependencies
            demo._queue = shared_queue
        self.running_app.blocks = demo
|
|
|
|
|
class SourceFileReloader(BaseReloader):
    """Reloader driven by ``watchfn`` that swaps in a freshly imported demo.

    Parameters:
        app: the running App whose ``blocks`` get replaced.
        watch_dirs: directories whose ``*.py`` files ``watchfn`` monitors.
        watch_file: module name that ``watchfn`` re-imports after a change.
        stop_event: once set, ``should_watch`` returns False.
        change_event: set every time new blocks are swapped in.
        demo_name: attribute of the reloaded module that holds the new demo.
    """

    def __init__(
        self,
        app: App,
        watch_dirs: list[str],
        watch_file: str,
        stop_event: threading.Event,
        change_event: threading.Event,
        demo_name: str = "demo",
    ) -> None:
        super().__init__()
        self.app = app
        self.watch_dirs = watch_dirs
        self.watch_file = watch_file
        self.demo_name = demo_name
        self.stop_event = stop_event
        self.change_event = change_event

    @property
    def running_app(self) -> App:
        return self.app

    def should_watch(self) -> bool:
        """True until the stop event has been set."""
        stop_requested = self.stop_event.is_set()
        return not stop_requested

    def stop(self) -> None:
        """Ask the watcher loop to finish by setting the stop event."""
        self.stop_event.set()

    def alert_change(self):
        """Set ``change_event`` to report that new blocks were swapped in."""
        self.change_event.set()

    def swap_blocks(self, demo: Blocks):
        """Swap in ``demo`` via the base implementation, then report the change."""
        super().swap_blocks(demo)
        self.alert_change()
|
|
|
|
|
def watchfn(reloader: SourceFileReloader):
    """Watch python files in a given module and hot-reload the demo on changes.

    Every ``poll_interval`` seconds the modification time of each ``*.py`` file
    under ``reloader.watch_dirs`` is checked. When a file has changed:

    1. every module in ``sys.modules`` whose source lies inside the watched
       directory containing the changed file is evicted (gradio's own
       ``reload.py`` is kept);
    2. ``reloader.watch_file`` is re-imported;
    3. the attribute ``reloader.demo_name`` of that module is swapped into the
       running app, unless its queue setup differs from the running demo's.

    Runs until ``reloader.should_watch()`` returns False.

    get_changes is taken from uvicorn's default file watcher.
    """
    import sys

    from gradio.reload import reload_thread

    # Mark the thread running watchfn as the reloading thread (presumably read
    # by gradio.reload so the app is not launched again - TODO confirm).
    reload_thread.running_reload = True

    # Pause between scans. Without it the loop re-walks every watched
    # directory back to back and keeps a CPU core busy.
    poll_interval = 0.05

    def get_changes() -> Path | None:
        """Return the first file whose mtime grew since it was first seen."""
        for file in iter_py_files():
            try:
                mtime = file.stat().st_mtime
            except OSError:
                # The file disappeared between listing and stat.
                continue

            old_time = mtimes.get(file)
            if old_time is None:
                mtimes[file] = mtime
                continue
            elif mtime > old_time:
                return file
        return None

    def iter_py_files() -> Iterator[Path]:
        for reload_dir in reload_dirs:
            for path in list(reload_dir.rglob("*.py")):
                yield path.resolve()

    module = None
    reload_dirs = [Path(dir_) for dir_ in reloader.watch_dirs]
    gradio_dir = Path(inspect.getfile(gradio)).parent
    mtimes = {}
    while reloader.should_watch():
        changed = get_changes()
        if changed:
            print(f"Changes detected in: {changed}")

            # Simulate a fresh import: drop every module loaded from the
            # watched directory that contains the changed file.
            dir_ = next(d for d in reload_dirs if is_in_or_equal(changed, d))
            for k in list(sys.modules):
                v = sys.modules[k]
                sourcefile = getattr(v, "__file__", None)
                # Keep gradio's reload.py loaded; it owns reload_thread.
                if (
                    sourcefile
                    and dir_ == gradio_dir
                    and sourcefile.endswith("reload.py")
                ):
                    continue
                if sourcefile and is_in_or_equal(sourcefile, dir_):
                    del sys.modules[k]

            try:
                module = importlib.import_module(reloader.watch_file)
                module = importlib.reload(module)
            except Exception as e:
                print(
                    f"Reloading {reloader.watch_file} failed with the following exception: "
                )
                traceback.print_exception(None, value=e, tb=None)
            else:
                # A missing demo attribute used to raise AttributeError and
                # silently kill this watcher thread.
                demo = getattr(module, reloader.demo_name, None)
                if demo is None:
                    print(
                        f"Reloading failed. Could not find '{reloader.demo_name}' "
                        f"in {reloader.watch_file}."
                    )
                elif reloader.queue_changed(demo):
                    print(
                        "Reloading failed. The new demo has a queue and the old one doesn't (or vice versa). "
                        "Please launch your demo again"
                    )
                else:
                    reloader.swap_blocks(demo)
            # Start a fresh mtime baseline after every reload attempt.
            mtimes = {}
        time.sleep(poll_interval)
|
|
|
|
|
def colab_check() -> bool:
    """
    Check if interface is launching from Google Colab.

    :return is_colab (bool): True if the string form of the active IPython shell
        contains "google.colab". Returns False otherwise, including when
        IPython cannot be imported.
    """
    try:
        from IPython.core.getipython import get_ipython

        shell_description = str(get_ipython())
    except (ImportError, NameError):
        return False
    return "google.colab" in shell_description
|
|
|
|
|
def kaggle_check() -> bool:
    """Return True when running in a Kaggle kernel.

    Detection is based on a non-empty KAGGLE_KERNEL_RUN_TYPE or
    GFOOTBALL_DATA_DIR environment variable.
    """
    marker_vars = ("KAGGLE_KERNEL_RUN_TYPE", "GFOOTBALL_DATA_DIR")
    return any(os.environ.get(var) for var in marker_vars)
|
|
|
|
|
def sagemaker_check() -> bool:
    """Return True if the caller's AWS identity ARN mentions "sagemaker".

    Any exception (for example boto3 not being installed) yields False.
    """
    try:
        import boto3

        sts = boto3.client("sts")
        caller_arn = sts.get_caller_identity()["Arn"]
        return "sagemaker" in caller_arn.lower()
    except Exception:
        return False
|
|
|
|
|
def ipython_check() -> bool:
    """
    Check if interface is launching from an IPython shell.

    Colab is not excluded: any active IPython shell counts.

    :return is_ipython (bool): True if ``get_ipython()`` returns a shell. Returns
        False otherwise, including when IPython cannot be imported.
    """
    try:
        from IPython.core.getipython import get_ipython

        return get_ipython() is not None
    except (ImportError, NameError):
        return False
|
|
|
|
|
def get_space() -> str | None:
    """Return the SPACE_ID environment variable when SYSTEM == "spaces", else None."""
    running_on_spaces = os.getenv("SYSTEM") == "spaces"
    return os.getenv("SPACE_ID") if running_on_spaces else None
|
|
|
|
|
def is_zero_gpu_space() -> bool:
    """True only when the SPACES_ZERO_GPU environment variable is exactly "true"."""
    flag = os.environ.get("SPACES_ZERO_GPU")
    return flag == "true"
|
|
|
|
|
def readme_to_html(article: str) -> str:
    """Fetch ``article`` as a URL and return the response body.

    If the request raises a ``requests`` exception (for instance when
    ``article`` is not a URL at all) or the status is not 200 OK, ``article``
    is returned unchanged. Despite the name, no HTML conversion happens here.
    """
    try:
        response = requests.get(article, timeout=3)
        if response.status_code != requests.codes.ok:
            return article
        return response.text
    except requests.exceptions.RequestException:
        return article
|
|
|
|
|
def show_tip(interface: gradio.Blocks, probability: float = 1.0) -> None:
    """Print a random tip from ``en["TIPS"]`` when tips are enabled.

    Parameters:
        interface: the Blocks instance; only its ``show_tips`` flag is read.
        probability: chance in [0, 1] of printing a tip when ``show_tips`` is
            set. ``random.random()`` is always below 1.0, so the default
            always prints. The previous magic threshold ``1.5`` had the same
            always-true effect but hid it.
    """
    if interface.show_tips and random.random() < probability:
        tip: str = random.choice(en["TIPS"])
        print(f"Tip: {tip}")
|
|
|
|
|
def launch_counter() -> None:
    """Increment the launch count stored in JSON_PATH, on a best-effort basis.

    The file is created with a count of 1 on first use. When the incremented
    count equals one of the milestone values, ``en["BETA_INVITE"]`` is printed.
    Any exception is silently ignored.
    """
    try:
        if os.path.exists(JSON_PATH):
            with open(JSON_PATH) as handle:
                record = json.load(handle)
            record["launches"] += 1
            if record["launches"] in (25, 50, 150, 500, 1000):
                print(en["BETA_INVITE"])
            with open(JSON_PATH, "w") as handle:
                handle.write(json.dumps(record))
        else:
            with open(JSON_PATH, "w+") as handle:
                json.dump({"launches": 1}, handle)
    except Exception:
        pass
|
|
|
|
|
def get_default_args(func: Callable) -> list[Any]:
    """Return the default of every parameter of ``func``, in signature order.

    Parameters without a default (including ``*args``/``**kwargs``) map to None.
    """
    defaults = []
    for param in inspect.signature(func).parameters.values():
        has_default = param.default is not inspect.Parameter.empty
        defaults.append(param.default if has_default else None)
    return defaults
|
|
|
|
|
def assert_configs_are_equivalent_besides_ids(
    config1: dict, config2: dict, root_keys: tuple = ("mode",)
):
    """Allows you to test if two different Blocks configs produce the same demo.

    Components are matched through the layout tree and the dependency lists
    instead of by id, so configs that differ only in component ids compare as
    equivalent.

    Parameters:
        config1 (dict): nested dict with config from the first Blocks instance
        config2 (dict): nested dict with config from the second Blocks instance
        root_keys (Tuple): an iterable consisting of which keys to test for equivalence at
            the root level of the config. By default, only "mode" is tested,
            so keys like "version" are ignored.
    Returns:
        True if every check passes.
    Raises:
        AssertionError: describing the first difference found.
    """
    # Work on copies: the dependency checks below pop keys.
    config1 = copy.deepcopy(config1)
    config2 = copy.deepcopy(config2)
    pp = pprint.PrettyPrinter(indent=2)

    for key in root_keys:
        assert config1[key] == config2[key], f"Configs have different: {key}"

    assert len(config1["components"]) == len(
        config2["components"]
    ), "# of components are different"

    def find_component(config, component_id):
        """Return the first component of ``config`` with the given id."""
        for component in config["components"]:
            if component["id"] == component_id:
                return component
        raise AssertionError(f"No component with id {component_id}")

    def assert_same_components(config1_id, config2_id):
        c1 = copy.deepcopy(find_component(config1, config1_id))
        c1.pop("id")
        c2 = copy.deepcopy(find_component(config2, config2_id))
        c2.pop("id")
        # pformat (not pprint): pprint writes to stdout and returns None, which
        # made the message read "None does not match None".
        assert json.dumps(c1) == json.dumps(
            c2
        ), f"{pp.pformat(c1)} does not match {pp.pformat(c2)}"

    def assert_same_id_lists(ids1, ids2, what):
        # zip() alone would silently ignore extra trailing ids.
        assert len(ids1) == len(ids2), f"# of {what} are different"
        for id1, id2 in zip(ids1, ids2):
            assert_same_components(id1, id2)

    def same_children_recursive(children1, children2):
        assert len(children1) == len(children2), "# of children are different"
        for child1, child2 in zip(children1, children2):
            assert_same_components(child1["id"], child2["id"])
            if "children" in child1 or "children" in child2:
                # A missing "children" key counts as no children, so a
                # one-sided key fails the length check instead of a KeyError.
                same_children_recursive(
                    child1.get("children", []), child2.get("children", [])
                )

    same_children_recursive(
        config1["layout"]["children"], config2["layout"]["children"]
    )

    assert len(config1["dependencies"]) == len(
        config2["dependencies"]
    ), "# of dependencies are different"
    for d1, d2 in zip(config1["dependencies"], config2["dependencies"]):
        assert_same_id_lists(d1.pop("targets"), d2.pop("targets"), "targets")
        assert_same_id_lists(d1.pop("inputs"), d2.pop("inputs"), "inputs")
        assert_same_id_lists(d1.pop("outputs"), d2.pop("outputs"), "outputs")

        assert d1 == d2, f"{d1} does not match {d2}"

    return True
|
|
|
|
|
def format_ner_list(input_string: str, ner_groups: list[dict[str, str | int]]):
    """Split ``input_string`` into ``(text, entity)`` pairs.

    Each group provides ``entity_group``, ``start`` and ``end`` character
    offsets (presumably sorted and non-overlapping - TODO confirm with callers).
    The text between groups, including empty gaps and the trailing remainder,
    is emitted with entity None. With no groups the whole string is returned
    as a single untagged pair.
    """
    if len(ner_groups) == 0:
        return [(input_string, None)]

    segments = []
    cursor = 0
    for group in ner_groups:
        label = group["entity_group"]
        span_start = group["start"]
        span_end = group["end"]
        segments.append((input_string[cursor:span_start], None))
        segments.append((input_string[span_start:span_end], label))
        cursor = span_end

    segments.append((input_string[cursor:], None))
    return segments
|
|
|
|
|
def delete_none(_dict: dict, skip_value: bool = False) -> dict:
    """
    Delete keys whose values are None from a dictionary, in place.

    Parameters:
        _dict: the dictionary to modify; the same object is returned.
        skip_value: if True, a "value" key is kept even when it is None.
    """
    doomed = [
        key
        for key, val in _dict.items()
        if val is None and not (skip_value and key == "value")
    ]
    for key in doomed:
        del _dict[key]
    return _dict
|
|
|
|
|
def resolve_singleton(_list: list[Any] | Any) -> Any:
    """Unwrap a length-1 sequence to its only element; return anything else as is."""
    return _list[0] if len(_list) == 1 else _list
|
|
|
|
|
def component_or_layout_class(cls_name: str) -> type[Component] | type[BlockContext]:
    """
    Look up a component, template or layout class by name.

    Each candidate class name is lower-cased and compared with ``cls_name``
    after removing its underscores. Components are searched first, then
    templates, then layouts. Only subclasses of Component or BlockContext
    qualify.

    Parameters:
        cls_name (str): lower-case string class name of a component
    Returns:
        cls: the component class
    Raises:
        ValueError: if no matching class is found.
    """
    import gradio.blocks
    import gradio.components
    import gradio.layouts
    import gradio.templates

    wanted = cls_name.replace("_", "")
    candidates = [
        (name, obj)
        for module in (gradio.components, gradio.templates, gradio.layouts)
        for name, obj in module.__dict__.items()
        if isinstance(obj, type)
    ]
    for name, obj in candidates:
        if name.lower() != wanted:
            continue
        if issubclass(obj, gradio.components.Component) or issubclass(
            obj, gradio.blocks.BlockContext
        ):
            return obj
    raise ValueError(f"No such component or layout: {cls_name}")
|
|
|
|
|
def run_coro_in_background(func: Callable, *args, **kwargs):
    """
    Runs coroutines in background.

    Calls ``func(*args, **kwargs)`` and schedules the resulting coroutine as a
    task on the current event loop, without awaiting it.

    Warning: be careful not to use this function outside the FastAPI scope,
    because the event loop may not have started yet there.
    You can use it in any scope reached by the FastAPI app.

    correct scope examples: endpoints in routes, Blocks.process_api
    incorrect scope examples: Blocks.launch

    Use startup_events in routes.py if you need to run a coro in background in Blocks.launch().

    Example:
        utils.run_coro_in_background(fn, *args, **kwargs)

    Args:
        func: coroutine function to call.
        *args: positional arguments forwarded to ``func``.
        **kwargs: keyword arguments forwarded to ``func``.

    Returns:
        The ``asyncio.Task`` wrapping the coroutine.
    """
    loop = asyncio.get_event_loop()
    coroutine = func(*args, **kwargs)
    task = loop.create_task(coroutine)
    return task
|
|
|
|
|
def run_sync_iterator_async(iterator):
    """Return ``next(iterator)``, turning exhaustion into ``StopAsyncIteration``.

    StopIteration cannot be propagated through coroutines or asyncio futures,
    so exhaustion is reported with StopAsyncIteration, which ``async for``
    understands.
    """
    try:
        item = next(iterator)
    except StopIteration:
        # Suppress the StopIteration context; it is an implementation detail.
        raise StopAsyncIteration() from None
    return item
|
|
|
|
|
class SyncToAsyncIterator:
    """Treat a synchronous iterator as an async one.

    Each ``__anext__`` pulls one item via ``run_sync_iterator_async`` inside
    ``anyio.to_thread.run_sync``, so a blocking ``next()`` runs in a worker
    thread. ``limiter`` is passed straight to anyio (presumably a
    CapacityLimiter bounding concurrent worker threads - TODO confirm).
    """

    def __init__(self, iterator, limiter) -> None:
        self.limiter = limiter
        self.iterator = iterator

    def __aiter__(self):
        return self

    async def __anext__(self):
        next_item = await anyio.to_thread.run_sync(
            run_sync_iterator_async, self.iterator, limiter=self.limiter
        )
        return next_item
|
|
|
|
|
async def async_iteration(iterator):
    """Await and return the next item of an async iterator.

    Raises StopAsyncIteration once the iterator is exhausted.
    """
    pending_step = iterator.__anext__()
    return await pending_step
|
|
|
|
|
@contextmanager
def set_directory(path: Path | str):
    """Temporarily change the current working directory to ``path``.

    The previous working directory, captured as an absolute path, is restored
    on exit, even when the ``with`` body raises.
    """
    previous_cwd = Path(os.getcwd())
    try:
        os.chdir(path)
        yield
    finally:
        os.chdir(previous_cwd)
|
|
|
|
|
def sanitize_value_for_csv(value: str | Number) -> str | Number:
    """
    Sanitizes a value that is being written to a CSV file to prevent CSV injection attacks.
    Reference: https://owasp.org/www-community/attacks/CSV_Injection

    Numbers are returned unchanged. A string is prefixed with a single quote
    when it starts with, or contains right after a comma, a character that a
    spreadsheet may treat as the start of a formula: ``=``, ``+``, ``-``,
    ``@``, tab, carriage return or newline.
    """
    if isinstance(value, Number):
        return value
    # OWASP lists tab (0x09) and carriage return (0x0D) next to = + - @;
    # "\r" was missing. "\n" is kept for backward compatibility.
    unsafe_prefixes = ("=", "+", "-", "@", "\t", "\r", "\n")
    unsafe_sequences = tuple(f",{prefix}" for prefix in unsafe_prefixes)
    if value.startswith(unsafe_prefixes) or any(
        sequence in value for sequence in unsafe_sequences
    ):
        value = f"'{value}"
    return value
|
|
|
|
|
def sanitize_list_for_csv(values: list[Any]) -> list[Any]:
    """
    Sanitizes a list of values (or a list of list of values) that is being written to a
    CSV file to prevent CSV injection attacks.

    Nested lists are handled one level deep. A new list is returned and the
    input is not modified.
    """
    return [
        [sanitize_value_for_csv(item) for item in value]
        if isinstance(value, list)
        else sanitize_value_for_csv(value)
        for value in values
    ]
|
|
|
|
|
def append_unique_suffix(name: str, list_of_names: list[str]):
    """Appends a numerical suffix to `name` so that it does not appear in `list_of_names`.

    ``name`` is returned unchanged if it is free. Otherwise ``name_1``,
    ``name_2``, ... are tried in order and the first unused one is returned.
    """
    taken: set[str] = set(list_of_names)
    if name not in taken:
        return name
    suffix = 1
    while f"{name}_{suffix}" in taken:
        suffix += 1
    return f"{name}_{suffix}"
|
|
|
|
|
def validate_url(possible_url: str, timeout: float | None = 10) -> bool:
    """Return True if ``possible_url`` answers with a non-error status.

    A HEAD request is tried first. If the server answers 403 or 405, the URL
    is retried with GET. Success means ``Response.ok`` (status below 400). Any
    exception (invalid URL, connection error, timeout, ...) yields False.

    Parameters:
        possible_url: the URL to check.
        timeout: seconds to wait for each request. Without a timeout,
            ``requests`` can block indefinitely on an unresponsive server.
            Pass None to wait forever (the previous behavior).
    """
    headers = {"User-Agent": "gradio (https://gradio.app/; [email protected])"}
    try:
        head_request = requests.head(possible_url, headers=headers, timeout=timeout)
        # Some servers reject HEAD; fall back to GET for those.
        if head_request.status_code == 405 or head_request.status_code == 403:
            return requests.get(possible_url, headers=headers, timeout=timeout).ok
        return head_request.ok
    except Exception:
        return False
|
|
|
|
|
def is_update(val):
    """True for dicts whose "__type__" entry contains the substring "update"."""
    if not isinstance(val, dict):
        return False
    return "update" in val.get("__type__", "")
|
|
|
|
|
def get_continuous_fn(fn: Callable, every: float) -> Callable:
    """Wrap ``fn`` in a generator function that calls it forever.

    Each call's result is yielded. If the result is a generator, its items are
    yielded one by one. The loop sleeps ``every`` seconds after each call.
    """

    def continuous_fn(*args):
        while True:
            result = fn(*args)
            if not isinstance(result, GeneratorType):
                yield result
            else:
                yield from result
            time.sleep(every)

    return continuous_fn
|
|
|
|
|
def function_wrapper(
    f, before_fn=None, before_args=None, after_fn=None, after_args=None
):
    """Wrap ``f`` so ``before_fn`` runs before it and ``after_fn`` after it.

    The wrapper has the same kind as ``f`` (async generator function,
    coroutine function, generator function or plain function) and copies its
    metadata with ``functools.wraps``. ``after_fn`` runs only when ``f``
    completes normally (for generators: once exhausted). It is skipped if
    ``f`` raises.
    """
    before_args = [] if before_args is None else before_args
    after_args = [] if after_args is None else after_args

    def run_before():
        if before_fn:
            before_fn(*before_args)

    def run_after():
        if after_fn:
            after_fn(*after_args)

    if inspect.isasyncgenfunction(f):

        @functools.wraps(f)
        async def asyncgen_wrapper(*args, **kwargs):
            run_before()
            async for response in f(*args, **kwargs):
                yield response
            run_after()

        return asyncgen_wrapper

    if asyncio.iscoroutinefunction(f):

        @functools.wraps(f)
        async def async_wrapper(*args, **kwargs):
            run_before()
            outcome = await f(*args, **kwargs)
            run_after()
            return outcome

        return async_wrapper

    if inspect.isgeneratorfunction(f):

        @functools.wraps(f)
        def gen_wrapper(*args, **kwargs):
            run_before()
            yield from f(*args, **kwargs)
            run_after()

        return gen_wrapper

    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        run_before()
        outcome = f(*args, **kwargs)
        run_after()
        return outcome

    return wrapper
|
|
|
|
|
def get_function_with_locals(fn: Callable, blocks: Blocks, event_id: str | None):
    """Wrap ``fn`` so ``blocks`` and ``event_id`` are stored on
    ``gradio.context.thread_data`` right before every call.
    """

    def set_thread_data(current_blocks, current_event_id):
        from gradio.context import thread_data

        thread_data.blocks = current_blocks
        thread_data.event_id = current_event_id

    return function_wrapper(
        fn, before_fn=set_thread_data, before_args=(blocks, event_id)
    )
|
|
|
|
|
async def cancel_tasks(task_ids: set[str]):
    """Cancel every task of the running loop whose name is in ``task_ids``.

    Waits for the cancelled tasks to finish. Their results or exceptions
    (including CancelledError) are collected by
    ``gather(..., return_exceptions=True)`` rather than raised.
    """
    targets = list(filter(lambda t: t.get_name() in task_ids, asyncio.all_tasks()))
    for target in targets:
        target.cancel()
    await asyncio.gather(*targets, return_exceptions=True)
|
|
|
|
|
def set_task_name(task, session_hash: str, fn_index: int, batch: bool):
    """Name ``task`` "<session_hash>_<fn_index>", the pattern get_cancel_function
    cancels by. Batched tasks are left unchanged.
    """
    if batch:
        return
    task.set_name(f"{session_hash}_{fn_index}")
|
|
|
|
|
def get_cancel_function(
    dependencies: list[dict[str, Any]]
) -> tuple[Callable, list[int]]:
    """Build a coroutine function that cancels the tasks of ``dependencies``.

    Each dependency is located, by equality, in
    ``Context.root_block.dependencies`` to obtain its fn_index. Nothing is
    collected while ``Context.root_block`` is unset.

    Returns:
        ``(cancel, fn_indices)``. ``cancel(session_hash)`` cancels the tasks
        named "<session_hash>_<fn_index>" for the collected indices.
    """
    fn_to_comp = {}
    for dep in dependencies:
        root = Context.root_block
        if not root:
            continue
        fn_index = next(
            index
            for index, candidate in enumerate(root.dependencies)
            if candidate == dep
        )
        fn_to_comp[fn_index] = [root.blocks[o] for o in dep["outputs"]]

    async def cancel(session_hash: str) -> None:
        await cancel_tasks({f"{session_hash}_{fn}" for fn in fn_to_comp})

    return cancel, list(fn_to_comp)
|
|
|
|
|
def get_type_hints(fn):
    """Return a dict mapping parameter names of ``fn`` to their type hints.

    Functions and methods are inspected directly. Other callables are
    inspected through their ``__call__``. Non-callables yield ``{}``. If
    ``typing.get_type_hints`` raises TypeError, annotations are evaluated one
    by one instead, and those that cannot be evaluated are skipped.
    """
    # Only referenced through ``locals()`` in the fallback below, so that
    # string annotations such as "gr.Request" or "OAuthProfile" evaluate.
    import gradio as gr
    from gradio import OAuthProfile, Request

    if inspect.isfunction(fn) or inspect.ismethod(fn):
        pass
    elif callable(fn):
        # e.g. instances defining __call__: use the bound __call__ method.
        fn = fn.__call__
    else:
        return {}

    try:
        return typing.get_type_hints(fn)
    except TypeError:
        # NOTE(review): presumably raised when an annotation cannot be
        # evaluated on this Python version (e.g. a "X | None" string before
        # 3.10) - TODO confirm. Fall back to per-parameter evaluation.
        type_hints = {}
        sig = inspect.signature(fn)
        for name, param in sig.parameters.items():
            if param.annotation is inspect.Parameter.empty:
                continue
            if param.annotation == "gr.OAuthProfile | None":
                # Special-cased so an optional OAuth profile is still
                # recognized; the `continue` just below keeps this entry.
                type_hints[name] = Optional[OAuthProfile]
            if "|" in str(param.annotation):
                # Any other union annotation is skipped.
                continue
            # Evaluate in this module's globals plus the local names above.
            # typing.ForwardRef expects a string, so non-string annotations
            # likely raise TypeError here and are skipped as well (verify).
            # NOTE: typing._eval_type is a private typing API.
            try:
                type_hints[name] = typing._eval_type(
                    typing.ForwardRef(param.annotation), globals(), locals()
                )
            except (NameError, TypeError):
                pass
        return type_hints
|
|
|
|
|
def is_special_typed_parameter(name, parameter_types):
    """Checks if parameter has a type hint designating it as a gr.Request, gr.EventData or gr.OAuthProfile.

    Parameters:
        name: the parameter name to look up.
        parameter_types: mapping of parameter name to type hint (as produced by
            get_type_hints).
    Returns:
        True if the hint is ``Request``, ``OAuthProfile``,
        ``Optional[OAuthProfile]`` or a subclass of ``EventData``. False
        otherwise, including when there is no hint.
    """
    # The docstring previously sat after these imports, where it was a no-op
    # expression and left __doc__ as None.
    from gradio.helpers import EventData
    from gradio.oauth import OAuthProfile
    from gradio.routes import Request

    hint = parameter_types.get(name)
    if not hint:
        return False
    is_request = hint == Request
    is_oauth_arg = hint in (OAuthProfile, Optional[OAuthProfile])
    is_event_data = inspect.isclass(hint) and issubclass(hint, EventData)
    return is_request or is_event_data or is_oauth_arg
|
|
|
|
|
def check_function_inputs_match(fn: Callable, inputs: list, inputs_as_dict: bool):
    """
    Checks if the input component set matches the function's parameters.

    Positional parameters whose hint marks them as gr.Request, gr.EventData or
    gr.OAuthProfile are not counted. An argument-count mismatch is reported
    with a single ``warnings.warn`` call, not through the return value.

    Parameters:
        fn: the function the inputs will be passed to.
        inputs: the input components.
        inputs_as_dict: if True, the inputs count as a single argument.
    Returns:
        A string error message if ``fn`` has a keyword-only parameter without a
        default value; None otherwise.
    """
    signature = inspect.signature(fn)
    parameter_types = get_type_hints(fn)
    min_args = 0
    max_args = 0
    infinity = -1
    for name, param in signature.parameters.items():
        # Identity test: `!=` against a default such as a numpy array returns
        # an array whose truth value is ambiguous (ValueError below).
        has_default = param.default is not param.empty
        if param.kind in [param.POSITIONAL_ONLY, param.POSITIONAL_OR_KEYWORD]:
            if not is_special_typed_parameter(name, parameter_types):
                if not has_default:
                    min_args += 1
                max_args += 1
        elif param.kind == param.VAR_POSITIONAL:
            max_args = infinity
        elif param.kind == param.KEYWORD_ONLY and not has_default:
            return f"Keyword-only args must have default values for function {fn}"
    arg_count = 1 if inputs_as_dict else len(inputs)
    # elif chain: with independent ifs, an exact-count mismatch also triggered
    # the "at least"/"maximum" warning for the same problem.
    if min_args == max_args and max_args != arg_count:
        warnings.warn(
            f"Expected {max_args} arguments for function {fn}, received {arg_count}."
        )
    elif arg_count < min_args:
        warnings.warn(
            f"Expected at least {min_args} arguments for function {fn}, received {arg_count}."
        )
    elif max_args != infinity and arg_count > max_args:
        warnings.warn(
            f"Expected maximum {max_args} arguments for function {fn}, received {arg_count}."
        )
|
|
|
|
|
def concurrency_count_warning(queue: Callable[P, T]) -> Callable[P, T]:
    """Decorate a ``queue`` method so it warns when ``concurrency_count`` is
    given on a ZeroGPU Space, where it cannot be overridden.

    The wrapped call is always made with the original arguments.
    """

    @functools.wraps(queue)
    def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
        # args[0] is the instance. Any extra positional argument is presumably
        # concurrency_count (TODO confirm against the queue signature).
        _self, *extra_positional = args
        on_zero_gpu = is_zero_gpu_space()
        if on_zero_gpu and (
            extra_positional or "concurrency_count" in kwargs
        ):
            warnings.warn(
                "Queue concurrency_count on ZeroGPU Spaces cannot be overridden "
                "and is always equal to Block's max_threads. "
                "Consider setting max_threads value on the Block instead"
            )
        return queue(*args, **kwargs)

    return wrapper
|
|
|
|
|
class TupleNoPrint(tuple):
    """A tuple whose ``repr()`` and ``str()`` are both the empty string.

    The contents still behave like a regular tuple (indexing, equality, len).
    """

    def __repr__(self):
        return ""

    __str__ = __repr__
|
|
|
|
|
class MatplotlibBackendMananger:
    """Context manager that switches matplotlib to the "agg" backend.

    The backend active on entry is restored on exit. The class name's
    spelling is kept as is so existing references keep working.
    """

    def __enter__(self):
        previous_backend = matplotlib.get_backend()
        self._original_backend = previous_backend
        matplotlib.use("agg")

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Implicitly returns None, so exceptions from the block propagate.
        matplotlib.use(self._original_backend)
|
|
|
|
|
def tex2svg(formula, *args):
    """Render a LaTeX math ``formula`` to inline SVG markup with matplotlib.

    The SVG's ``width`` attribute and ``<metadata>`` element are removed, and
    its height is converted from pt to em (relative to the 20pt font size used
    for rendering). A zero-size ``<span>`` holding the formula text is
    prepended, presumably so the formula can still be copied as text.

    Parameters:
        formula: math expression without the surrounding ``$`` delimiters.
        *args: ignored.
    Returns:
        An HTML string: the hidden ``<span>`` followed by the ``<svg>`` element.
    """
    import html

    with MatplotlibBackendMananger():
        import matplotlib.pyplot as plt

        fontsize = 20
        dpi = 300
        plt.rc("mathtext", fontset="cm")
        fig = plt.figure(figsize=(0.01, 0.01))
        try:
            fig.text(0, 0, rf"${formula}$", fontsize=fontsize)
            output = BytesIO()
            fig.savefig(
                output,
                dpi=dpi,
                transparent=True,
                format="svg",
                bbox_inches="tight",
                pad_inches=0.0,
            )
        finally:
            # Release the figure even when the formula fails to render, so
            # pyplot does not keep accumulating open figures.
            plt.close(fig)
        output.seek(0)
        xml_code = output.read().decode("utf-8")
        svg_start = xml_code.index("<svg ")
        svg_code = xml_code[svg_start:]
        svg_code = re.sub(r"<metadata>.*<\/metadata>", "", svg_code, flags=re.DOTALL)
        svg_code = re.sub(r' width="[^"]+"', "", svg_code)
        height_match = re.search(r'height="([\d.]+)pt"', svg_code)
        if height_match:
            height = float(height_match.group(1))
            new_height = height / fontsize
            svg_code = re.sub(
                r'height="[\d.]+pt"', f'height="{new_height}em"', svg_code
            )
        # Escape the raw formula: characters such as "<" or "&" (e.g. "a<b")
        # would otherwise be parsed as HTML markup and break the output.
        copy_code = f"<span style='font-size: 0px'>{html.escape(formula)}</span>"
    return f"{copy_code}{svg_code}"
|
|
|
|
|
def abspath(path: str | Path) -> Path:
    """Returns absolute path of a str or Path path, but does not resolve symlinks.

    Absolute inputs are returned unchanged, so ``..`` components are not
    collapsed. A relative path that is, or lies under, a symlink is joined
    onto the current working directory. Any other relative path is resolved.
    """
    path = Path(path)

    if path.is_absolute():
        return path

    # Check the path and every parent: resolving would follow any of these links.
    is_symlink = path.is_symlink() or any(
        parent.is_symlink() for parent in path.parents
    )

    if is_symlink or path == path.resolve():
        return Path.cwd() / path
    else:
        return path.resolve()


def is_in_or_equal(path_1: str | Path, path_2: str | Path):
    """
    True if path_1 is a descendant (i.e. located within) path_2 or if the paths are the
    same, returns False otherwise.

    ``abspath`` does not normalize absolute paths, so a ``..`` component could
    make a path that nominally lies under ``path_2`` escape it (for example
    ``/a/b/../../etc`` relative to ``/a``). Any ``..`` left in the relative
    path is therefore treated as "not inside".

    Parameters:
        path_1: str or Path (should be a file)
        path_2: str or Path (can be a file or directory)
    """
    path_1, path_2 = abspath(path_1), abspath(path_2)
    try:
        relative = path_1.relative_to(path_2)
    except ValueError:
        return False
    return ".." not in relative.parts
|
|
|
|
|
def get_serializer_name(block: Block) -> str | None:
    """Return the name of the class that provides ``block.serialize``.

    Returns None when ``block`` has no ``serialize`` attribute or when the
    providing class cannot be determined. For a bound method, the MRO walk
    stops at the first class that is either a ``Serializable`` subclass from a
    ``gradio_client`` module or defines the method itself.
    """
    if not hasattr(block, "serialize"):
        return None

    def get_class_that_defined_method(meth: Callable):
        # Unwrap functools.partial objects to the callable they wrap.
        if isinstance(meth, functools.partial):
            return get_class_that_defined_method(meth.func)
        # Bound methods (or builtins bound to an object): walk the MRO of the
        # bound object's class.
        if inspect.ismethod(meth) or (
            inspect.isbuiltin(meth)
            and getattr(meth, "__self__", None) is not None
            and getattr(meth.__self__, "__class__", None)
        ):
            for cls in inspect.getmro(meth.__self__.__class__):
                # Prefer a gradio_client serializer class when reached first...
                if issubclass(cls, Serializable) and "gradio_client" in cls.__module__:
                    return cls
                # ...otherwise the first class defining the method by name.
                if meth.__name__ in cls.__dict__:
                    return cls
            meth = getattr(meth, "__func__", meth)
        # Plain functions: look the owner up by qualified name in the
        # function's module, ignoring any "<locals>" nesting.
        if inspect.isfunction(meth):
            cls = getattr(
                inspect.getmodule(meth),
                meth.__qualname__.split(".<locals>", 1)[0].rsplit(".", 1)[0],
                None,
            )
            if isinstance(cls, type):
                return cls
        # Last resort: objects exposing __objclass__ (presumably descriptors
        # such as method wrappers) name their owning class there.
        return getattr(meth, "__objclass__", None)

    cls = get_class_that_defined_method(block.serialize)
    if cls:
        return cls.__name__
|
|
|
|
|
# Non-greedy match of one "<...>" tag. "." does not match newlines, so a tag
# that spans lines is left in place.
HTML_TAG_RE = re.compile("<.*?>")


def remove_html_tags(raw_html: str | None) -> str:
    """Strip every match of HTML_TAG_RE from ``raw_html``. None or "" gives ""."""
    text = raw_html if raw_html else ""
    return HTML_TAG_RE.sub("", text)
|
|
|
|
|
def find_user_stack_level() -> int:
    """
    Find the first stack frame not inside Gradio.

    Starting from this function's own frame (depth 0) and walking outward,
    returns the depth of the first frame whose file path does not contain a
    "/gradio/" directory. If every frame is inside gradio, returns the number
    of frames walked.
    """
    depth = 0
    frame = inspect.currentframe()
    while frame is not None:
        normalized = inspect.getfile(frame).replace(os.sep, "/")
        if "/gradio/" not in normalized:
            return depth
        frame = frame.f_back
        depth += 1
    return depth
|
|