Spaces:
Runtime error
Runtime error
# The contents of this file are automatically written by | |
# tools/generate_schema_wrapper.py. Do not modify directly. | |
from __future__ import annotations | |
import contextlib | |
import copy | |
import datetime as dt | |
import inspect | |
import json | |
import sys | |
import textwrap | |
from collections import defaultdict | |
from collections.abc import Iterable, Iterator, Mapping, Sequence | |
from functools import partial | |
from importlib.metadata import version as importlib_version | |
from itertools import chain, zip_longest | |
from math import ceil | |
from typing import ( | |
TYPE_CHECKING, | |
Any, | |
Final, | |
Generic, | |
Literal, | |
TypeVar, | |
Union, | |
cast, | |
overload, | |
) | |
from typing_extensions import TypeAlias | |
import jsonschema | |
import jsonschema.exceptions | |
import jsonschema.validators | |
import narwhals.stable.v1 as nw | |
from packaging.version import Version | |
# This leads to circular imports with the vegalite module. Currently, this works | |
# but be aware that when you access it in this script, the vegalite module might | |
# not yet be fully instantiated in case your code is being executed during import time | |
from altair import vegalite | |
if sys.version_info >= (3, 12): | |
from typing import Protocol, TypeAliasType, runtime_checkable | |
else: | |
from typing_extensions import Protocol, TypeAliasType, runtime_checkable | |
if TYPE_CHECKING: | |
from types import ModuleType | |
from typing import ClassVar | |
from referencing import Registry | |
from altair.typing import ChartType | |
if sys.version_info >= (3, 13): | |
from typing import TypeIs | |
else: | |
from typing_extensions import TypeIs | |
if sys.version_info >= (3, 11): | |
from typing import Never, Self | |
else: | |
from typing_extensions import Never, Self | |
_OptionalModule: TypeAlias = "ModuleType | None" | |
# Flat list of errors as yielded by a jsonschema validator.
ValidationErrorList: TypeAlias = list[jsonschema.exceptions.ValidationError]
# Errors bucketed under a string key. In this module the key is either the
# JSON path of the offending element or the name of the failing validator.
GroupedValidationErrors: TypeAlias = dict[str, ValidationErrorList]

# This URI is arbitrary and could be anything else. It just cannot be an empty
# string as we need to reference the schema registered in
# the referencing.Registry.
_VEGA_LITE_ROOT_URI: Final = "urn:vega-lite-schema"

# Ideally, jsonschema specification would be parsed from the current Vega-Lite
# schema instead of being hardcoded here as a default value.
# However, due to circular imports between this module and the altair.vegalite
# modules, this information is not yet available at this point as altair.vegalite
# is only partially loaded. The draft version which is used is unlikely to
# change often so it's ok to keep this. There is also a test which validates
# that this value is always the same as in the Vega-Lite schema.
_DEFAULT_JSON_SCHEMA_DRAFT_URL: Final = "http://json-schema.org/draft-07/schema#"

# If DEBUG_MODE is True, then schema objects are converted to dict and
# validated at creation time. This slows things down, particularly for
# larger specs, but leads to much more useful tracebacks for the user.
# Individual schema classes can override this by setting the
# class-level _class_is_valid_at_instantiation attribute to False
DEBUG_MODE: bool = True

# Version string of the installed jsonschema distribution; compared against
# "4.18" in `_use_referencing_library` to choose the reference-resolution API.
jsonschema_version_str = importlib_version("jsonschema")
def enable_debug_mode() -> None:
    """Turn the module-level ``DEBUG_MODE`` flag on."""
    global DEBUG_MODE
    DEBUG_MODE = True
def disable_debug_mode() -> None:
    """Turn the module-level ``DEBUG_MODE`` flag off."""
    global DEBUG_MODE
    DEBUG_MODE = False
@contextlib.contextmanager
def debug_mode(arg: bool) -> Iterator[None]:
    """
    Temporarily set ``DEBUG_MODE`` to ``arg`` for the duration of a ``with`` block.

    The value that was active on entry is restored on exit, including when the
    body of the ``with`` block raises.

    Parameters
    ----------
    arg
        The value ``DEBUG_MODE`` should have inside the ``with`` block.
    """
    global DEBUG_MODE
    original = DEBUG_MODE
    DEBUG_MODE = arg
    try:
        yield
    finally:
        # Always restore, so an exception in the body cannot leak the
        # temporary setting.
        DEBUG_MODE = original
@overload
def validate_jsonschema(
    spec: Any,
    schema: dict[str, Any],
    rootschema: dict[str, Any] | None = ...,
    *,
    raise_error: Literal[True] = ...,
) -> Never: ...


@overload
def validate_jsonschema(
    spec: Any,
    schema: dict[str, Any],
    rootschema: dict[str, Any] | None = ...,
    *,
    raise_error: Literal[False],
) -> jsonschema.exceptions.ValidationError | None: ...


def validate_jsonschema(
    spec,
    schema: dict[str, Any],
    rootschema: dict[str, Any] | None = None,
    *,
    raise_error: bool = True,
) -> jsonschema.exceptions.ValidationError | None:
    """
    Validates the passed in spec against the schema in the context of the rootschema.

    If any errors are found, they are deduplicated and prioritized
    and only the most relevant errors are kept. Errors are then either raised
    or returned, depending on the value of `raise_error`.

    Parameters
    ----------
    spec
        The object to validate.
    schema
        The schema ``spec`` is validated against.
    rootschema
        Schema used to resolve references; ``schema`` itself is used if omitted.
    raise_error
        If ``True``, the selected error is raised, otherwise it is returned.

    Returns
    -------
    ``None`` when no validation errors were found, otherwise (only if
    ``raise_error`` is ``False``) one representative ``ValidationError`` that
    carries all grouped errors in its ``_all_errors`` attribute.
    """
    errors = _get_errors_from_spec(spec, schema, rootschema=rootschema)
    if not errors:
        return None

    leaf_errors = _get_leaves_of_error_tree(errors)
    grouped_errors = _group_errors_by_json_path(leaf_errors)
    grouped_errors = _subset_to_most_specific_json_paths(grouped_errors)
    grouped_errors = _deduplicate_errors(grouped_errors)

    # Nothing special about this first error but we need to choose one
    # which can be raised
    main_error: Any = next(iter(grouped_errors.values()))[0]
    # All errors are then attached as a new attribute to ValidationError so that
    # they can be used in SchemaValidationError to craft a more helpful
    # error message. Setting a new attribute like this is not ideal as
    # it then no longer matches the type ValidationError. It would be better
    # to refactor this function to never raise but only return errors.
    main_error._all_errors = grouped_errors
    if raise_error:
        raise main_error
    return main_error
def _get_errors_from_spec(
    spec: dict[str, Any],
    schema: dict[str, Any],
    rootschema: dict[str, Any] | None = None,
) -> ValidationErrorList:
    """
    Collect every validation error of ``spec`` against ``schema``.

    References are resolved against ``rootschema`` (or ``schema`` when no
    rootschema is given). Neither schema is validated itself; both are
    treated as valid.
    """
    # jsonschema.validate is deliberately avoided because it would validate the
    # schema as well. The schema is handed straight to the validator class
    # instead, for two reasons: it originates from Vega-Lite rather than from
    # user input, so checking it is unnecessary, and the "uri-reference" format
    # checker rejects some references because URIs in "$ref" are not encoded,
    # e.g. '#/definitions/ValueDefWithCondition<MarkPropFieldOrDatumDef,
    # (Gradient|string|null)>' is a valid $ref in a Vega-Lite schema but not a
    # valid URI reference due to characters such as '<'.
    draft_url = _get_json_schema_draft_url(rootschema or schema)
    validator_cls = jsonschema.validators.validator_for({"$schema": draft_url})

    options: dict[str, Any] = {}
    if hasattr(validator_cls, "FORMAT_CHECKER"):
        options["format_checker"] = validator_cls.FORMAT_CHECKER

    if _use_referencing_library():
        schema = _prepare_references_in_schema(schema)
        # Evaluated after the reassignment above on purpose: without a
        # rootschema the registry is built from the prepared schema.
        options["registry"] = _get_referencing_registry(
            rootschema or schema, draft_url
        )
    elif rootschema is None:
        # No resolver is necessary if the schema is already the full schema
        options["resolver"] = None
    else:
        options["resolver"] = jsonschema.RefResolver.from_schema(rootschema)

    return list(validator_cls(schema, **options).iter_errors(spec))
def _get_json_schema_draft_url(schema: dict[str, Any]) -> str:
    """Return the ``"$schema"`` entry of ``schema``, or the module default if absent."""
    if "$schema" in schema:
        return schema["$schema"]
    return _DEFAULT_JSON_SCHEMA_DRAFT_URL
def _use_referencing_library() -> bool:
    """
    Tell whether the installed jsonschema is at least version 4.18.

    In version 4.18.0, the jsonschema package deprecated RefResolver in favor
    of the referencing library.
    """
    installed = Version(jsonschema_version_str)
    minimum = Version("4.18")
    return installed >= minimum
def _prepare_references_in_schema(schema: dict[str, Any]) -> dict[str, Any]:
    """
    Return a deep copy of ``schema`` with ``_VEGA_LITE_ROOT_URI`` prefixed to all ``$ref`` values.

    The input is never modified: it might still be the dictionary attached to
    an Altair class ``_schema`` attribute.

    ``$ref`` values can only be nested in dictionaries or lists, as the schema
    comes from the Vega-Lite json schema and json only has arrays (-> lists in
    Python) and objects (-> dictionaries in Python). Dictionaries are visited
    when they are dictionary values or direct members of a list value.
    """
    prepared = copy.deepcopy(schema)
    # Explicit work list of dictionaries that still need to be processed
    pending: list[dict[str, Any]] = [prepared]
    while pending:
        node = pending.pop()
        for key, value in node.items():
            if key == "$ref":
                # Replacing the value of an existing key is safe during iteration
                node[key] = _VEGA_LITE_ROOT_URI + value
            elif isinstance(value, dict):
                pending.append(value)
            elif isinstance(value, list):
                pending.extend(item for item in value if isinstance(item, dict))
    return prepared
# The referencing library is not always available, so it is imported inside the
# function body; this function is only executed when it is installed.
def _get_referencing_registry(
    rootschema: dict[str, Any], json_schema_draft_url: str | None = None
) -> Registry:
    """
    Build a ``referencing.Registry`` holding ``rootschema`` under ``_VEGA_LITE_ROOT_URI``.

    Parameters
    ----------
    rootschema
        The schema to register as the root resource.
    json_schema_draft_url
        Draft URL used to select the specification. When ``None`` it is read
        from ``rootschema`` via ``_get_json_schema_draft_url``.
    """
    # Referencing is a dependency of newer jsonschema versions, starting with the
    # version that is specified in _use_referencing_library and we therefore
    # can expect that it is installed if the function returns True.
    # We ignore 'import' mypy errors which happen when the referencing library
    # is not installed. That's ok as in these cases this function is not called.
    # We also have to ignore 'unused-ignore' errors as mypy raises those in case
    # referencing is installed.
    import referencing  # type: ignore[import,unused-ignore]
    import referencing.jsonschema  # type: ignore[import,unused-ignore]

    draft_url = json_schema_draft_url
    if draft_url is None:
        draft_url = _get_json_schema_draft_url(rootschema)

    specification = referencing.jsonschema.specification_with(draft_url)
    root_resource = specification.create_resource(rootschema)
    empty_registry = referencing.Registry()
    return empty_registry.with_resource(
        uri=_VEGA_LITE_ROOT_URI, resource=root_resource
    )
def _json_path(err: jsonschema.exceptions.ValidationError) -> str:
    """
    Build the JSON path of ``err`` from its ``absolute_path``.

    Drop in replacement for the ``.json_path`` property of the jsonschema
    ValidationError class, which is not available with jsonschema<4.0.1.
    More info, see https://github.com/vega/altair/issues/3038.

    Integer path elements are rendered as ``[index]``, all others as ``.name``.
    """
    segments = (
        f"[{elem}]" if isinstance(elem, int) else "." + elem
        for elem in err.absolute_path
    )
    return "$" + "".join(segments)
def _group_errors_by_json_path(
    errors: ValidationErrorList,
) -> GroupedValidationErrors:
    """
    Bucket errors by the JSON path of the element they complain about.

    The key is the ``json_path`` attribute of the jsonschema ValidationError
    when present, otherwise the result of ``_json_path``. The path points to
    the offending element within a chart specification and can therefore be
    considered as an identifier of an 'issue' in the chart that needs to be fixed.
    """
    grouped: GroupedValidationErrors = {}
    for error in errors:
        fallback_path = _json_path(error)
        path = getattr(error, "json_path", fallback_path)
        grouped.setdefault(path, []).append(error)
    return grouped
def _get_leaves_of_error_tree(
    errors: ValidationErrorList,
) -> ValidationErrorList:
    """
    Return all "leaf" errors of the error trees rooted at ``errors``.

    jsonschema stores the errors from subschemas which caused an error in that
    error's ``context``. A leaf is an error with an empty ``context``: no
    further errors caused it, so it is the most specific one with the most
    specific error message. Leaves are returned in depth-first order.
    """

    def walk(nodes):
        for node in nodes:
            if node.context:
                # `node` was caused by errors in subschemas; descend into them.
                yield from walk(node.context)
            else:
                yield node

    return list(walk(errors))
def _subset_to_most_specific_json_paths(
    errors_by_json_path: GroupedValidationErrors,
) -> GroupedValidationErrors:
    """
    Removes key (json path), value (errors) pairs where the json path is an ancestor of another json path.

    For example if `errors_by_json_path` has two keys, `$.encoding.X` and `$.encoding.X.tooltip`,
    then the first one will be removed and only the second one is returned.
    This is done under the assumption that more specific json paths give more helpful error messages to the user.

    A path only counts as an ancestor when the longer path continues with a
    path separator (``.`` or ``[``). A plain string prefix is not enough:
    `$.encoding.x` is not an ancestor of `$.encoding.x2`, so both are kept.
    """

    def is_ancestor(path: str, other: str) -> bool:
        # Strictly longer, same start, and the next character begins a new
        # path segment rather than extending the last property name.
        if len(other) <= len(path) or not other.startswith(path):
            return False
        return other[len(path)] in ".["

    all_paths = list(errors_by_json_path)
    return {
        json_path: errors
        for json_path, errors in errors_by_json_path.items()
        if not any(is_ancestor(json_path, other) for other in all_paths)
    }
def _contained_at_start_of_one_of_other_values(x: str, values: Sequence[str]) -> bool:
    """Return ``True`` if some entry of ``values`` differs from ``x`` and starts with ``x``."""
    for candidate in values:
        # Identical values do not count as "contained at start of other value".
        # These cases should be handled separately
        if candidate != x and candidate.startswith(x):
            return True
    return False
def _deduplicate_errors(
    grouped_errors: GroupedValidationErrors,
) -> GroupedValidationErrors:
    """
    Drop errors that are near-duplicates or unhelpful, separately for each json path.

    Some errors have very similar error messages or are just in general not
    helpful for a user. Per json path, the errors are grouped by validator,
    ``enum`` and ``additionalProperties`` groups go through their dedicated
    deduplication, every group is deduplicated by message, and finally the
    "'value' is a required property" errors are removed. This can be extended
    over time to handle new cases that come up.
    """
    result: GroupedValidationErrors = {}
    for json_path, element_errors in grouped_errors.items():
        kept: ValidationErrorList = []
        for validator, errors in _group_errors_by_validator(element_errors).items():
            if validator == "enum":
                errors = _deduplicate_enum_errors(errors)
            elif validator == "additionalProperties":
                errors = _deduplicate_additional_properties_errors(errors)
            kept.extend(_deduplicate_by_message(errors))

        # Removes any ValidationError "'value' is a required property" as these
        # errors are unlikely to be the relevant ones for the user. They come from
        # validation against a schema definition where the output of `alt.value`
        # would be valid. However, if a user uses `alt.value`, the `value` keyword
        # is included automatically from that function and so it's unlikely
        # that this was what the user intended if the keyword is not present
        # in the first place.
        result[json_path] = [
            error for error in kept if not _is_required_value_error(error)
        ]
    return result
def _is_required_value_error(err: jsonschema.exceptions.ValidationError) -> bool:
    """Return ``True`` for a ``required`` error whose only required property is ``"value"``."""
    if err.validator != "required":
        return False
    return err.validator_value == ["value"]
def _group_errors_by_validator(errors: ValidationErrorList) -> GroupedValidationErrors:
    """
    Bucket errors by the json schema "validator" that caused them.

    For example if the error is that a value is not one of an enumeration in the json schema
    then the "validator" is `"enum"`, if the error is due to an unknown property that
    was set although no additional properties are allowed then "validator" is
    `"additionalProperties"`, etc.
    """
    by_validator: GroupedValidationErrors = {}
    for error in errors:
        # mypy wrongly sees err.validator as Optional[Validator] instead of
        # str, which it is according to the documentation and all tested cases
        by_validator.setdefault(error.validator, []).append(error)  # type: ignore[arg-type]
    return by_validator
def _deduplicate_enum_errors(errors: ValidationErrorList) -> ValidationErrorList:
    """
    Deduplicate enum errors by removing the errors whose allowed values are the leading part of another error's allowed values.

    For example, if `enum` contains two errors and one has `validator_value` (i.e. accepted values) ["A", "B"] and the
    other one ["A", "B", "C"] then the first one is removed and the final
    `enum` list only contains the error with ["A", "B", "C"].

    The comparison is done element by element on the value lists. It therefore
    also works for enum members which are not strings (numbers, booleans,
    ``None``) and does not confuse ["A", "B"] with the start of ["A", "BC"].
    """
    if len(errors) <= 1:
        return errors

    # Values (and therefore `validator_value`) of an enum are always arrays,
    # see https://json-schema.org/understanding-json-schema/reference/generic.html#enumerated-values
    allowed_values = [list(err.validator_value) for err in errors]  # type: ignore[call-overload]

    def is_strict_prefix(values: list, other: list) -> bool:
        return len(values) < len(other) and other[: len(values)] == values

    return [
        err
        for err, values in zip(errors, allowed_values)
        if not any(is_strict_prefix(values, other) for other in allowed_values)
    ]
def _deduplicate_additional_properties_errors(
    errors: ValidationErrorList,
) -> ValidationErrorList:
    """
    Keep only the shortest additional-property error when all errors share one ``anyOf`` parent.

    If there are multiple additional property errors it usually means that the
    offending element was validated against multiple schemas and its parent is
    a common anyOf validator. The error messages produced from these cases are
    usually very similar and we just take the shortest one. For example,
    the following 3 errors are raised for the `unknown` channel option in
    `alt.X("variety", unknown=2)`:
    - "Additional properties are not allowed ('unknown' was unexpected)"
    - "Additional properties are not allowed ('field', 'unknown' were unexpected)"
    - "Additional properties are not allowed ('field', 'type', 'unknown' were unexpected)".

    In every other situation ``errors`` is returned unchanged.
    """
    if len(errors) <= 1:
        return errors

    # Only prioritize if all parent errors are the same anyOf error. Can't
    # think of a chart spec where this would not be the case but still allow
    # for it to not break anything.
    first, *others = errors
    shared_parent = first.parent
    if shared_parent is None or shared_parent.validator != "anyOf":
        return errors
    if any(err.parent is not shared_parent for err in others):
        return errors

    shortest = min(errors, key=lambda err: len(err.message))
    return [shortest]
def _deduplicate_by_message(errors: ValidationErrorList) -> ValidationErrorList:
    """
    Deduplicate errors by message, keeping the original order in case it was chosen intentionally.

    Each message keeps the position of its first occurrence; the error stored
    for it is the last one carrying that message.
    """
    by_message: dict[str, jsonschema.exceptions.ValidationError] = {}
    for error in errors:
        by_message[error.message] = error
    return list(by_message.values())
def _subclasses(cls: type[Any]) -> Iterator[type[Any]]:
    """
    Breadth-first sequence of all classes which inherit from cls.

    Each class is yielded once, one inheritance level at a time; the order
    within a level is not defined.
    """
    visited: set[type[Any]] = set()
    frontier: set[type[Any]] = {cls}
    while frontier:
        visited |= frontier
        frontier = {
            child for parent in frontier for child in parent.__subclasses__()
        }
        yield from frontier - visited
def _from_array_like(obj: Iterable[Any], /) -> list[Any]:
    """
    Convert ``obj`` into a plain list.

    ``obj`` is first handed to narwhals as a series; if that raises
    ``TypeError``, it is consumed with ``list`` instead.
    """
    try:
        return nw.from_native(obj, strict=True, series_only=True).to_list()
    except TypeError:
        return list(obj)
def _from_date_datetime(obj: dt.date | dt.datetime, /) -> dict[str, Any]:
    """
    Parse native `datetime.(date|datetime)` into a `DateTime`_ schema.

    The result always has ``year``, ``month`` and ``date``. For a ``datetime``
    with a time other than midnight, ``hours``, ``minutes``, ``seconds`` and
    ``milliseconds`` are added. ``utc`` is set for ``datetime.timezone.utc``.

    Raises
    ------
    TypeError
        If ``obj`` carries a timezone other than ``datetime.timezone.utc``.

    .. _DateTime:
        https://vega.github.io/vega-lite/docs/datetime.html
    """
    result: dict[str, Any] = {"year": obj.year, "month": obj.month, "date": obj.day}
    if not isinstance(obj, dt.datetime):
        return result

    if obj.time() != dt.time.min:
        result["hours"] = obj.hour
        result["minutes"] = obj.minute
        result["seconds"] = obj.second
        # Whole milliseconds only; sub-millisecond precision is dropped
        result["milliseconds"] = obj.microsecond // 1_000

    tzinfo = obj.tzinfo
    if not tzinfo:
        return result
    if tzinfo is not dt.timezone.utc:
        msg = (
            f"Unsupported timezone {tzinfo!r}.\n"
            "Only `'UTC'` or naive (local) datetimes are permitted.\n"
            "See https://altair-viz.github.io/user_guide/generated/core/altair.DateTime.html"
        )
        raise TypeError(msg)
    result["utc"] = True
    return result
def _todict(obj: Any, context: dict[str, Any] | None, np_opt: Any, pd_opt: Any) -> Any:  # noqa: C901
    """
    Convert an object to a dict representation.

    Containers are converted recursively; entries of a dict whose value is
    ``Undefined`` are dropped. The checks below run in a fixed order and the
    first matching one decides the conversion. Objects matching none of them
    are returned unchanged.

    Parameters
    ----------
    obj
        The object to convert.
    context
        Passed on to ``SchemaBase.to_dict``.
    np_opt
        The numpy module, or ``None`` to skip the numpy-specific conversions.
    pd_opt
        The pandas module, or ``None`` to skip the ``Timestamp`` conversion.
    """

    def convert(item: Any) -> Any:
        return _todict(item, context, np_opt, pd_opt)

    if np_opt is not None:
        if isinstance(obj, np_opt.ndarray):
            return [convert(v) for v in obj]
        if isinstance(obj, np_opt.number):
            return float(obj)
        if isinstance(obj, np_opt.datetime64):
            text = str(obj)
            # See https://github.com/vega/altair/issues/1027 for why this is necessary.
            return text if "T" in text else text + "T00:00:00"

    if isinstance(obj, SchemaBase):
        return obj.to_dict(validate=False, context=context)
    if isinstance(obj, (list, tuple)):
        return [convert(v) for v in obj]
    if isinstance(obj, dict):
        return {k: convert(v) for k, v in obj.items() if v is not Undefined}
    if isinstance(obj, SchemaLike):
        return obj.to_dict()
    if pd_opt is not None and isinstance(obj, pd_opt.Timestamp):
        return pd_opt.Timestamp(obj).isoformat()
    if _is_iterable(obj, exclude=(str, bytes)):
        return convert(_from_array_like(obj))
    if isinstance(obj, dt.date):
        return _from_date_datetime(obj)
    return obj
def _resolve_references(
    schema: dict[str, Any], rootschema: dict[str, Any] | None = None
) -> dict[str, Any]:
    """
    Resolve schema references until there is no $ref anymore in the top-level of the dictionary.

    References are looked up in ``rootschema``, or in ``schema`` itself if no
    rootschema is given.
    """
    root = rootschema or schema

    if _use_referencing_library():
        # Named differently to show that this is not the
        # jsonschema.RefResolver but instead a Resolver from the referencing
        # library
        referencing_resolver = _get_referencing_registry(root).resolver()
        while "$ref" in schema:
            target = _VEGA_LITE_ROOT_URI + schema["$ref"]
            schema = referencing_resolver.lookup(target).contents
        return schema

    legacy_resolver = jsonschema.RefResolver.from_schema(root)
    while "$ref" in schema:
        with legacy_resolver.resolving(schema["$ref"]) as resolved:
            schema = resolved
    return schema
class SchemaValidationError(jsonschema.ValidationError):
    """
    A ``jsonschema.ValidationError`` whose message is rewritten to be easier to read.

    The message is built from the grouped errors attached to the original
    error as ``_all_errors`` (see ``validate_jsonschema``), or from the
    original error alone when that attribute is missing.
    """

    def __init__(self, obj: SchemaBase, err: jsonschema.ValidationError) -> None:
        """
        A wrapper for ``jsonschema.ValidationError`` with friendlier traceback.

        Parameters
        ----------
        obj
            The instance that failed ``self.validate(...)``.
        err
            The original ``ValidationError``.

        Notes
        -----
        We do not raise `from err` as else the resulting traceback is very long
        as it contains part of the Vega-Lite schema.
        It would also first show the less helpful `ValidationError` instead of
        the more user friendly `SchemaValidationError`.
        """
        super().__init__(**err._contents())
        self.obj = obj
        self._errors: GroupedValidationErrors = getattr(
            err, "_all_errors", {getattr(err, "json_path", _json_path(err)): [err]}
        )
        # This is the message from err
        self._original_message = self.message
        self.message = self._get_message()

    def __str__(self) -> str:
        return self.message

    def _get_message(self) -> str:
        """Build the full message, covering at most the first three error groups."""

        def indent_second_line_onwards(message: str, indent: int = 4) -> str:
            modified_lines: list[str] = []
            for idx, line in enumerate(message.split("\n")):
                if idx > 0 and len(line) > 0:
                    line = " " * indent + line
                modified_lines.append(line)
            return "\n".join(modified_lines)

        # Only show a maximum of 3 errors as else the final message returned by this
        # method could get very long.
        error_messages: list[str] = [
            self._get_message_for_errors_group(errors)
            for errors in list(self._errors.values())[:3]
        ]

        message = ""
        if len(error_messages) > 1:
            error_messages = [
                indent_second_line_onwards(f"Error {error_id}: {m}")
                for error_id, m in enumerate(error_messages, start=1)
            ]
            message += "Multiple errors were found.\n\n"
        message += "\n\n".join(error_messages)
        return message

    def _get_message_for_errors_group(
        self,
        errors: ValidationErrorList,
    ) -> str:
        """Return the stripped message for all errors which share one json path."""
        if errors[0].validator == "additionalProperties":
            # During development, we only found cases where an additionalProperties
            # error was raised if that was the only error for the offending instance
            # as identifiable by the json path. Therefore, we just check here the first
            # error. However, other constellations might exist in which case
            # this should be adapted so that other error messages are shown as well.
            message = self._get_additional_properties_error_message(errors[0])
        else:
            message = self._get_default_error_message(errors=errors)

        return message.strip()

    def _get_additional_properties_error_message(
        self,
        error: jsonschema.exceptions.ValidationError,
    ) -> str:
        """Output all existing parameters when an unknown parameter is specified."""
        altair_cls = self._get_altair_class_for_error(error)
        param_dict_keys = inspect.signature(altair_cls).parameters.keys()
        param_names_table = self._format_params_as_table(param_dict_keys)

        # Error messages for these errors look like this:
        # "Additional properties are not allowed ('unknown' was unexpected)"
        # Line below extracts "unknown" from this string
        parameter_name = error.message.split("('")[-1].split("'")[0]
        cls_name = altair_cls.__name__
        return (
            f"`{cls_name}` has no parameter named '{parameter_name}'\n"
            "Existing parameter names are:\n"
            f"{param_names_table}\n"
            f"See the help for `{cls_name}` to read the full description of these parameters"
        )

    def _get_altair_class_for_error(
        self, error: jsonschema.exceptions.ValidationError
    ) -> type[SchemaBase]:
        """
        Try to get the lowest class possible in the chart hierarchy so it can be displayed in the error message.

        This should lead to more informative error messages pointing the user closer to the source of the issue.
        """
        for prop_name in reversed(error.absolute_path):
            # Check if str as e.g. first item can be a 0
            if isinstance(prop_name, str):
                # Slicing (instead of indexing) keeps an empty key from raising
                potential_class_name = prop_name[:1].upper() + prop_name[1:]
                cls = getattr(vegalite, potential_class_name, None)
                if cls is not None:
                    break
        else:
            # Did not find a suitable class based on traversing the path so we fall
            # back on the class of the top-level object which created
            # the SchemaValidationError
            cls = self.obj.__class__
        return cls

    @staticmethod
    def _format_params_as_table(param_dict_keys: Iterable[str]) -> str:
        """
        Format param names into a table so that they are easier to read.

        ``kwds`` and ``self`` are left out. The names fill the table column by
        column; every row ends with a newline. An empty string is returned
        when no names remain.
        """
        param_names: tuple[str, ...] = tuple(
            name for name in param_dict_keys if name not in {"kwds", "self"}
        )
        if not param_names:
            return ""

        # Worst case scenario with the same longest param name in the same
        # row for all columns
        max_name_length = max(len(name) for name in param_names)
        max_column_width = 80
        # Output a square table if not too big (since it is easier to read)
        num_param_names = len(param_names)
        square_columns = int(ceil(num_param_names**0.5))
        # At least one column, even if a single name is wider than the limit
        columns = max(
            1, min(max_column_width // max(max_name_length, 1), square_columns)
        )

        # Compute roughly equal column heights to evenly divide the param names
        def split_into_equal_parts(n: int, p: int) -> list[int]:
            return [n // p + 1] * (n % p) + [n // p] * (p - n % p)

        column_heights = split_into_equal_parts(num_param_names, columns)

        # Section the param names into columns and compute their widths
        param_names_columns: list[tuple[str, ...]] = []
        column_max_widths: list[int] = []
        last_end_idx: int = 0
        for ch in column_heights:
            param_names_columns.append(param_names[last_end_idx : last_end_idx + ch])
            column_max_widths.append(
                max(len(param_name) for param_name in param_names_columns[-1])
            )
            last_end_idx = ch + last_end_idx

        # Transpose the param name columns into rows and pad each cell to the
        # width of the longest param in its column
        column_pad = 3
        table_rows: list[str] = []
        for param_names_row in zip_longest(*param_names_columns, fillvalue=""):
            cells = (
                f"{param_name:<{width + column_pad}}"
                for param_name, width in zip(param_names_row, column_max_widths)
            )
            table_rows.append("".join(cells) + "\n")
        return "".join(table_rows)

    def _get_default_error_message(
        self,
        errors: ValidationErrorList,
    ) -> str:
        """
        Build the message for a group of errors about the same offending value.

        ``enum`` and ``type`` errors are summarised as the list of valid
        values; the raw messages of all other errors are appended unchanged.
        """
        bullet_points: list[str] = []
        errors_by_validator = _group_errors_by_validator(errors)
        if "enum" in errors_by_validator:
            for error in errors_by_validator["enum"]:
                bullet_points.append(f"one of {error.validator_value}")

        if "type" in errors_by_validator:
            types = [f"'{err.validator_value}'" for err in errors_by_validator["type"]]
            point = "of type "
            if len(types) == 1:
                point += types[0]
            elif len(types) == 2:
                point += f"{types[0]} or {types[1]}"
            else:
                point += ", ".join(types[:-1]) + f", or {types[-1]}"
            bullet_points.append(point)

        # It should not matter which error is specifically used as they are all
        # about the same offending instance (i.e. invalid value), so we can just
        # take the first one
        error = errors[0]
        # Add a summary line when parameters are passed an invalid value
        # For example: "'asdf' is an invalid value for `stack`
        message = f"'{error.instance}' is an invalid value"
        if error.absolute_path:
            message += f" for `{error.absolute_path[-1]}`"

        # Add bullet points
        if len(bullet_points) == 0:
            message += ".\n\n"
        elif len(bullet_points) == 1:
            message += f". Valid values are {bullet_points[0]}.\n\n"
        else:
            # We don't use .capitalize below to make the first letter uppercase
            # as that makes the rest of the message lowercase
            bullet_points = [point[0].upper() + point[1:] for point in bullet_points]
            message += ". Valid values are:\n\n"
            message += "\n".join([f"- {point}" for point in bullet_points])
            message += "\n\n"

        # Add unformatted messages of any remaining errors which were not
        # considered so far. This is not expected to be used but more exists
        # as a fallback for cases which were not known during development.
        remaining = (
            "\n".join(e.message for e in group)
            for validator, group in errors_by_validator.items()
            if validator not in {"enum", "type"}
        )
        message += "".join(remaining)
        return message
# Covariant type variable constrained to the three literal type names below.
_JSON_VT_co = TypeVar(
    "_JSON_VT_co",
    Literal["string"],
    Literal["object"],
    Literal["array"],
    covariant=True,
)
"""
One of a subset of JSON Schema `primitive types`_:
["string", "object", "array"]
.. _primitive types:
https://json-schema.org/draft-07/json-schema-validation#rfc.section.6.1.1
"""

# Generic alias for a mapping from the literal key "type" to one of the
# type names permitted by `_JSON_VT_co`.
_TypeMap = TypeAliasType(
    "_TypeMap", Mapping[Literal["type"], _JSON_VT_co], type_params=(_JSON_VT_co,)
)
"""
A single item JSON Schema using the `type`_ keyword.
This may represent **one of**:
{"type": "string"}
{"type": "object"}
{"type": "array"}
.. _type:
https://json-schema.org/understanding-json-schema/reference/type
"""
# NOTE: Type checkers want opposing things:
# - `mypy`   : Covariant type variable "_JSON_VT_co" used in protocol where invariant one is expected  [misc]
# - `pyright`: Type variable "_JSON_VT_co" used in generic protocol "SchemaLike" should be covariant [reportInvalidTypeVarUse]
# Siding with `pyright` as this is consistent with https://github.com/python/typeshed/blob/9e506eb5e8fc2823db8c60ad561b1145ff114947/stdlib/typing.pyi#L690
# `runtime_checkable` is required because `_todict` performs
# `isinstance(obj, SchemaLike)`, which raises `TypeError` for a plain Protocol.
@runtime_checkable
class SchemaLike(Generic[_JSON_VT_co], Protocol):  # type: ignore[misc]
    """
    Represents ``altair`` classes which *may* not derive ``SchemaBase``.

    Attributes
    ----------
    _schema
        A single item JSON Schema using the `type`_ keyword.

    Notes
    -----
    Should be kept tightly defined to the **minimum** requirements for:
        - Converting into a form that can be validated by `jsonschema`_.
        - Avoiding calling ``.to_dict()`` on a class external to ``altair``.
    - ``_schema`` is more accurately described as a ``ClassVar``
        - See `discussion`_ for blocking issue.

    .. _jsonschema:
        https://github.com/python-jsonschema/jsonschema
    .. _type:
        https://json-schema.org/understanding-json-schema/reference/type
    .. _discussion:
        https://github.com/python/typing/discussions/1424
    """

    _schema: _TypeMap[_JSON_VT_co]

    def to_dict(self, *args, **kwds) -> Any: ...
@runtime_checkable
class ConditionLike(SchemaLike[Literal["object"]], Protocol):
    """
    Represents the wrapped state of a conditional encoding or property.

    The protocol is decorated with ``runtime_checkable`` so that it can be
    used as the second argument of ``isinstance``.

    Attributes
    ----------
    condition
        One or more (predicate, statement) pairs which each form a condition.

    Notes
    -----
    - Can be extended with additional conditions.
    - *Does not* define a default value, but can be finalized with one.
    """

    condition: Any
    _schema: _TypeMap[Literal["object"]] = {"type": "object"}
class UndefinedType:
    """A singleton object for marking undefined parameters."""

    # Cached sole instance; created lazily on the first construction.
    __instance = None

    def __new__(cls, *args, **kwargs) -> Self:
        """Return the cached instance, creating it on first use."""
        existing = cls.__instance
        if isinstance(existing, cls):
            return existing
        created = object.__new__(cls, *args, **kwargs)
        cls.__instance = created
        return cls.__instance

    def __repr__(self) -> str:
        return "Undefined"


Undefined = UndefinedType()
T = TypeVar("T")
Optional: TypeAlias = Union[T, UndefinedType]
"""One of ``T`` specified type(s), or the ``Undefined`` singleton.

Examples
--------
The parameters ``short``, ``long`` accept the same range of types::

    # ruff: noqa: UP006, UP007
    from altair.typing import Optional

    def func_1(
        short: Optional[str | bool | float | dict[str, Any] | SchemaBase] = Undefined,
        long: Union[
            str, bool, float, Dict[str, Any], SchemaBase, UndefinedType
        ] = Undefined,
    ): ...

This is distinct from `typing.Optional <https://typing.readthedocs.io/en/latest/spec/historical.html#union-and-optional>`__.

``altair.typing.Optional`` treats ``None`` like any other type::

    # ruff: noqa: UP006, UP007
    from altair.typing import Optional

    def func_2(
        short: Optional[str | float | dict[str, Any] | None | SchemaBase] = Undefined,
        long: Union[
            str, float, Dict[str, Any], None, SchemaBase, UndefinedType
        ] = Undefined,
    ): ...
"""
def is_undefined(obj: Any) -> TypeIs[UndefinedType]:
    """
    Return ``True`` only when ``obj`` is the ``Undefined`` singleton.

    Notes
    -----
    - A bare ``obj is Undefined`` does not narrow ``UndefinedType`` out of a union,
      because a type checker must assume other ``UndefinedType`` instances could exist.
    - The current `typing spec advises <https://typing.readthedocs.io/en/latest/spec/concepts.html#support-for-singleton-types-in-unions>`__
      using an ``Enum`` for singletons.
    - Without that, an explicit guard such as this one is needed to inform the type checker.
    """
    return obj is Undefined
@overload
def _shallow_copy(obj: _CopyImpl) -> _CopyImpl: ...
@overload
def _shallow_copy(obj: Any) -> Any: ...
def _shallow_copy(obj: _CopyImpl | Any) -> _CopyImpl | Any:
    """
    Return a one-level copy of ``obj`` when a copy implementation exists.

    - ``SchemaBase`` instances are copied with ``obj.copy(deep=False)``.
    - ``list`` and ``dict`` instances are copied with their own ``.copy()``.
    - Any other object is returned unchanged (by reference).
    """
    if isinstance(obj, SchemaBase):
        return obj.copy(deep=False)
    if isinstance(obj, (list, dict)):
        return obj.copy()
    return obj
@overload
def _deep_copy(obj: _CopyImpl, by_ref: set[str]) -> _CopyImpl: ...
@overload
def _deep_copy(obj: Any, by_ref: set[str]) -> Any: ...
def _deep_copy(obj: _CopyImpl | Any, by_ref: set[str]) -> _CopyImpl | Any:
    """
    Recursively copy ``SchemaBase``, ``list`` and ``dict`` structures.

    Parameters
    ----------
    obj
        The object to copy. Objects of any other type are returned unchanged.
    by_ref
        Keys whose values are stored by reference instead of being copied.
        Applies to the keyword arguments of ``SchemaBase`` objects and to
        ``dict`` keys at every nesting level.
    """
    # Named ``recurse`` (not ``copy``) so the imported ``copy`` module is not shadowed.
    recurse = partial(_deep_copy, by_ref=by_ref)
    if isinstance(obj, SchemaBase):
        # A custom ``__deepcopy__`` takes precedence over the generic rebuild below.
        if copier := getattr(obj, "__deepcopy__", None):
            with debug_mode(False):
                return copier(obj)
        # Kept lazy: the positional arguments are copied while the constructor
        # call below unpacks them, i.e. inside ``debug_mode(False)``.
        args = (recurse(arg) for arg in obj._args)
        kwds = {k: (v if k in by_ref else recurse(v)) for k, v in obj._kwds.items()}
        with debug_mode(False):
            return obj.__class__(*args, **kwds)
    if isinstance(obj, list):
        return [recurse(v) for v in obj]
    if isinstance(obj, dict):
        return {k: (v if k in by_ref else recurse(v)) for k, v in obj.items()}
    return obj
class SchemaBase:
    """
    Base class for schema wrappers.

    Each derived class should set the ``_schema`` class attribute (and optionally
    the ``_rootschema`` class attribute) which is used for validation.
    """

    _schema: ClassVar[dict[str, Any] | Any] = None
    _rootschema: ClassVar[dict[str, Any] | None] = None
    _class_is_valid_at_instantiation: ClassVar[bool] = True

    def __init__(self, *args: Any, **kwds: Any) -> None:
        # Two valid options for initialization, which should be handled by
        # derived classes:
        # - a single arg with no kwds, for, e.g. {'type': 'string'}
        # - zero args with zero or more kwds for {'type': 'object'}
        if self._schema is None:
            msg = (
                f"Cannot instantiate object of type {self.__class__}: "
                "_schema class attribute is not defined."
            )
            raise ValueError(msg)

        if kwds:
            assert len(args) == 0
        else:
            assert len(args) in {0, 1}

        # use object.__setattr__ because we override setattr below.
        object.__setattr__(self, "_args", args)
        object.__setattr__(self, "_kwds", kwds)

        if DEBUG_MODE and self._class_is_valid_at_instantiation:
            self.to_dict(validate=True)

    def copy(
        self, deep: bool | Iterable[Any] = True, ignore: list[str] | None = None
    ) -> Self:
        """
        Return a copy of the object.

        Parameters
        ----------
        deep : boolean or list, optional
            If True (default) then return a deep copy of all dict, list, and
            SchemaBase objects within the object structure.
            If False, then only copy the top object.
            If a list or iterable, then only copy the listed attributes.
        ignore : list, optional
            A list of keys for which the contents should not be copied, but
            only stored by reference.
        """
        if deep is True:
            return cast("Self", _deep_copy(self, set(ignore) if ignore else set()))
        with debug_mode(False):
            copy = self.__class__(*self._args, **self._kwds)
        if _is_iterable(deep):
            for attr in deep:
                copy[attr] = _shallow_copy(copy._get(attr))
        return copy

    def _get(self, attr, default=Undefined):
        """Get an attribute, returning default if not present."""
        attr = self._kwds.get(attr, Undefined)
        if attr is Undefined:
            attr = default
        return attr

    def __getattr__(self, attr):
        # reminder: getattr is called after the normal lookups
        if attr == "_kwds":
            # ``_kwds`` is looked up below; raising here avoids infinite
            # recursion when the instance attribute has not been set.
            raise AttributeError()
        if attr in self._kwds:
            return self._kwds[attr]
        else:
            try:
                _getattr = super().__getattr__  # pyright: ignore[reportAttributeAccessIssue]
            except AttributeError:
                _getattr = super().__getattribute__
            return _getattr(attr)

    def __setattr__(self, item, val) -> None:
        self._kwds[item] = val

    def __getitem__(self, item):
        return self._kwds[item]

    def __setitem__(self, item, val) -> None:
        self._kwds[item] = val

    def __repr__(self) -> str:
        name = type(self).__name__
        if kwds := self._kwds:
            it = (f"{k}: {v!r}" for k, v in sorted(kwds.items()) if v is not Undefined)
            args = ",\n".join(it).replace("\n", "\n ")
            LB, RB = "{", "}"
            return f"{name}({LB}\n {args}\n{RB})"
        elif self._args:
            return f"{name}({self._args[0]!r})"
        else:
            # Neither positional nor keyword arguments were supplied; indexing
            # ``_args`` here would raise ``IndexError``.
            return f"{name}()"

    def __eq__(self, other: Any) -> bool:
        return (
            type(self) is type(other)
            and self._args == other._args
            and self._kwds == other._kwds
        )

    def to_dict(
        self,
        validate: bool = True,
        *,
        ignore: list[str] | None = None,
        context: dict[str, Any] | None = None,
    ) -> dict[str, Any]:
        """
        Return a dictionary representation of the object.

        Parameters
        ----------
        validate : bool, optional
            If True (default), then validate the result against the schema.
        ignore : list[str], optional
            A list of keys to ignore.
        context : dict[str, Any], optional
            A context dictionary.

        Raises
        ------
        SchemaValidationError :
            If ``validate`` and the result does not conform to the schema.
        ValueError :
            If the instance holds both a positional value and keyword properties.

        Notes
        -----
        - ``ignore``, ``context`` are usually not needed to be specified as a user.
        - *Technical*: ``ignore`` will **not** be passed to child :meth:`.to_dict()`.
        """
        context = context or {}
        ignore = ignore or []
        opts = _get_optional_modules(np_opt="numpy", pd_opt="pandas")

        if self._args and not self._kwds:
            kwds = self._args[0]
        elif not self._args:
            kwds = self._kwds.copy()
            exclude = {*ignore, "shorthand"}
            # Popped so that it is not forwarded to child ``to_dict`` calls.
            if parsed := context.pop("parsed_shorthand", None):
                kwds = _replace_parsed_shorthand(parsed, kwds)
            kwds = {k: v for k, v in kwds.items() if k not in exclude}
            # A bare string mark is expanded into its object form.
            if (mark := kwds.get("mark")) and isinstance(mark, str):
                kwds["mark"] = {"type": mark}
        else:
            msg = f"{type(self)} instance has both a value and properties : cannot serialize to dict"
            raise ValueError(msg)
        result = _todict(kwds, context=context, **opts)
        if validate:
            # NOTE: Don't raise `from err`, see `SchemaValidationError` doc
            try:
                self.validate(result)
            except jsonschema.ValidationError as err:
                raise SchemaValidationError(self, err) from None
        return result

    def to_json(
        self,
        validate: bool = True,
        indent: int | str | None = 2,
        sort_keys: bool = True,
        *,
        ignore: list[str] | None = None,
        context: dict[str, Any] | None = None,
        **kwargs,
    ) -> str:
        """
        Emit the JSON representation for this object as a string.

        Parameters
        ----------
        validate : bool, optional
            If True (default), then validate the result against the schema.
        indent : int, optional
            The number of spaces of indentation to use. The default is 2.
        sort_keys : bool, optional
            If True (default), sort keys in the output.
        ignore : list[str], optional
            A list of keys to ignore.
        context : dict[str, Any], optional
            A context dictionary.
        **kwargs
            Additional keyword arguments are passed to ``json.dumps()``

        Raises
        ------
        SchemaValidationError :
            If ``validate`` and the result does not conform to the schema.

        Notes
        -----
        - ``ignore``, ``context`` are usually not needed to be specified as a user.
        - *Technical*: ``ignore`` will **not** be passed to child :meth:`.to_dict()`.
        """
        if ignore is None:
            ignore = []
        if context is None:
            context = {}
        dct = self.to_dict(validate=validate, ignore=ignore, context=context)
        return json.dumps(dct, indent=indent, sort_keys=sort_keys, **kwargs)

    @classmethod
    def _default_wrapper_classes(cls) -> Iterator[type[SchemaBase]]:
        """Return the set of classes used within cls.from_dict()."""
        return _subclasses(SchemaBase)

    @classmethod
    def from_dict(
        cls: type[TSchemaBase], dct: dict[str, Any], validate: bool = True
    ) -> TSchemaBase:
        """
        Construct class from a dictionary representation.

        Parameters
        ----------
        dct : dictionary
            The dict from which to construct the class
        validate : boolean
            If True (default), then validate the input against the schema.

        Raises
        ------
        jsonschema.ValidationError :
            If ``validate`` and ``dct`` does not conform to the schema
        """
        if validate:
            cls.validate(dct)
        converter = _FromDict(cls._default_wrapper_classes())
        return converter.from_dict(dct, cls)

    @classmethod
    def from_json(
        cls,
        json_string: str,
        validate: bool = True,
        **kwargs: Any,
        # Type hints for this method would get rather complicated
        # if we want to provide a more specific return type
    ) -> ChartType:
        """
        Instantiate the object from a valid JSON string.

        Parameters
        ----------
        json_string : string
            The string containing a valid JSON chart specification.
        validate : boolean
            If True (default), then validate the input against the schema.
        **kwargs :
            Additional keyword arguments are passed to json.loads

        Returns
        -------
        chart : Chart object
            The altair Chart object built from the specification.
        """
        dct: dict[str, Any] = json.loads(json_string, **kwargs)
        return cls.from_dict(dct, validate=validate)  # type: ignore[return-value]

    @classmethod
    def validate(
        cls, instance: dict[str, Any], schema: dict[str, Any] | None = None
    ) -> None:
        """Validate the instance against the class schema in the context of the rootschema."""
        if schema is None:
            schema = cls._schema
        # For the benefit of mypy
        assert schema is not None
        validate_jsonschema(instance, schema, rootschema=cls._rootschema or cls._schema)

    @classmethod
    def resolve_references(cls, schema: dict[str, Any] | None = None) -> dict[str, Any]:
        """Resolve references in the context of this object's schema or root schema."""
        schema_to_pass = schema or cls._schema
        # For the benefit of mypy
        assert schema_to_pass is not None
        return _resolve_references(
            schema=schema_to_pass,
            rootschema=(cls._rootschema or cls._schema or schema),
        )

    @classmethod
    def validate_property(
        cls, name: str, value: Any, schema: dict[str, Any] | None = None
    ) -> None:
        """Validate a property against property schema in the context of the rootschema."""
        opts = _get_optional_modules(np_opt="numpy", pd_opt="pandas")
        value = _todict(value, context={}, **opts)
        props = cls.resolve_references(schema or cls._schema).get("properties", {})
        validate_jsonschema(
            value, props.get(name, {}), rootschema=cls._rootschema or cls._schema
        )

    def __dir__(self) -> list[str]:
        # Include keyword properties so they show up in tab completion.
        return sorted(chain(super().__dir__(), self._kwds))
def _get_optional_modules(**modules: str) -> dict[str, _OptionalModule]: | |
""" | |
Returns packages only if they have already been imported - otherwise they return `None`. | |
This is useful for `isinstance` checks. | |
For example, if `pandas` has not been imported, then an object is | |
definitely not a `pandas.Timestamp`. | |
Parameters | |
---------- | |
**modules | |
Keyword-only binding from `{alias: module_name}`. | |
Examples | |
-------- | |
>>> import pandas as pd # doctest: +SKIP | |
>>> import polars as pl # doctest: +SKIP | |
>>> from altair.utils.schemapi import _get_optional_modules # doctest: +SKIP | |
>>> | |
>>> _get_optional_modules(pd="pandas", pl="polars", ibis="ibis") # doctest: +SKIP | |
{ | |
"pd": <module 'pandas' from '...'>, | |
"pl": <module 'polars' from '...'>, | |
"ibis": None, | |
} | |
If the user later imports ``ibis``, it would appear in subsequent calls. | |
>>> import ibis # doctest: +SKIP | |
>>> | |
>>> _get_optional_modules(ibis="ibis") # doctest: +SKIP | |
{ | |
"ibis": <module 'ibis' from '...'>, | |
} | |
""" | |
return {k: sys.modules.get(v) for k, v in modules.items()} | |
def _replace_parsed_shorthand( | |
parsed_shorthand: dict[str, Any], kwds: dict[str, Any] | |
) -> dict[str, Any]: | |
""" | |
`parsed_shorthand` is added by `FieldChannelMixin`. | |
It's used below to replace shorthand with its long form equivalent | |
`parsed_shorthand` is removed from `context` if it exists so that it is | |
not passed to child `to_dict` function calls. | |
""" | |
# Prevent that pandas categorical data is automatically sorted | |
# when a non-ordinal data type is specifed manually | |
# or if the encoding channel does not support sorting | |
if "sort" in parsed_shorthand and ( | |
"sort" not in kwds or kwds["type"] not in {"ordinal", Undefined} | |
): | |
parsed_shorthand.pop("sort") | |
kwds.update( | |
(k, v) | |
for k, v in parsed_shorthand.items() | |
if kwds.get(k, Undefined) is Undefined | |
) | |
return kwds | |
TSchemaBase = TypeVar("TSchemaBase", bound=SchemaBase)

_CopyImpl = TypeVar("_CopyImpl", SchemaBase, dict[Any, Any], list[Any])
"""
Types which have an implementation in ``SchemaBase.copy()``.

All other types are returned **by reference**.
"""
def _is_dict(obj: Any | dict[Any, Any]) -> TypeIs[dict[Any, Any]]: | |
return isinstance(obj, dict) | |
def _is_list(obj: Any | list[Any]) -> TypeIs[list[Any]]: | |
return isinstance(obj, list) | |
def _is_iterable( | |
obj: Any, *, exclude: type | tuple[type, ...] = (str, bytes) | |
) -> TypeIs[Iterable[Any]]: | |
return not isinstance(obj, exclude) and isinstance(obj, Iterable) | |
def _passthrough(*args: Any, **kwds: Any) -> Any | dict[str, Any]: | |
return args[0] if args else kwds | |
class _FromDict:
    """
    Class used to construct SchemaBase class hierarchies from a dict.

    The primary purpose of using this class is to be able to build a hash table
    that maps schemas to their wrapper classes. The candidate classes are
    specified in the ``wrapper_classes`` positional-only argument to the constructor.
    """

    # Top-level schema keys that are ignored when hashing a schema.
    _hash_exclude_keys = ("definitions", "title", "description", "$schema", "id")

    def __init__(self, wrapper_classes: Iterable[type[SchemaBase]], /) -> None:
        # Create a mapping of a schema hash to a list of matching classes
        # This lets us quickly determine the correct class to construct
        self.class_dict: dict[int, list[type[SchemaBase]]] = defaultdict(list)
        for tp in wrapper_classes:
            if tp._schema is not None:
                self.class_dict[self.hash_schema(tp._schema)].append(tp)

    @classmethod
    def hash_schema(cls, schema: dict[str, Any], use_json: bool = True) -> int:
        """
        Compute a python hash for a nested dictionary which properly handles dicts, lists, sets, and tuples.

        At the top level, the function excludes from the hashed schema all keys
        listed in ``_hash_exclude_keys``.

        This implements two methods: one based on conversion to JSON, and one based
        on recursive conversions of unhashable to hashable types; the former seems
        to be slightly faster in several benchmarks.

        Parameters
        ----------
        schema
            The schema to hash. Top-level filtering only applies when it is a ``dict``.
        use_json
            If True (default), hash the key-sorted JSON serialization;
            otherwise hash a recursively frozen copy of ``schema``.
        """
        if cls._hash_exclude_keys and isinstance(schema, dict):
            schema = {
                key: val
                for key, val in schema.items()
                if key not in cls._hash_exclude_keys
            }
        if use_json:
            s = json.dumps(schema, sort_keys=True)
            return hash(s)
        else:

            def _freeze(val):
                if isinstance(val, dict):
                    return frozenset((k, _freeze(v)) for k, v in val.items())
                elif isinstance(val, set):
                    return frozenset(map(_freeze, val))
                elif isinstance(val, (list, tuple)):
                    return tuple(map(_freeze, val))
                else:
                    return val

            return hash(_freeze(schema))

    @overload
    def from_dict(
        self,
        dct: TSchemaBase,
        tp: None = ...,
        schema: None = ...,
        rootschema: None = ...,
        default_class: Any = ...,
    ) -> TSchemaBase: ...
    @overload
    def from_dict(
        self,
        dct: dict[str, Any] | list[dict[str, Any]],
        tp: Any = ...,
        schema: Any = ...,
        rootschema: Any = ...,
        default_class: type[TSchemaBase] = ...,  # pyright: ignore[reportInvalidTypeVarUse]
    ) -> TSchemaBase: ...
    @overload
    def from_dict(
        self,
        dct: dict[str, Any],
        tp: None = ...,
        schema: dict[str, Any] = ...,
        rootschema: None = ...,
        default_class: Any = ...,
    ) -> SchemaBase: ...
    @overload
    def from_dict(
        self,
        dct: dict[str, Any],
        tp: type[TSchemaBase],
        schema: None = ...,
        rootschema: None = ...,
        default_class: Any = ...,
    ) -> TSchemaBase: ...
    @overload
    def from_dict(
        self,
        dct: dict[str, Any] | list[dict[str, Any]],
        tp: type[TSchemaBase],
        schema: dict[str, Any],
        rootschema: dict[str, Any] | None = ...,
        default_class: Any = ...,
    ) -> Never: ...
    def from_dict(
        self,
        dct: dict[str, Any] | list[dict[str, Any]] | TSchemaBase,
        tp: type[TSchemaBase] | None = None,
        schema: dict[str, Any] | None = None,
        rootschema: dict[str, Any] | None = None,
        default_class: Any = _passthrough,
    ) -> TSchemaBase | SchemaBase:
        """
        Construct an object from a dict representation.

        Parameters
        ----------
        dct
            The value to convert. A ``SchemaBase`` instance is returned unchanged.
        tp
            The wrapper class to construct. Takes precedence over ``schema``.
        schema
            Schema used to look up the wrapper class when ``tp`` is not given.
        rootschema
            Schema used to resolve references; falls back to the root schema
            of ``tp`` or to the current schema.
        default_class
            Callable used when no wrapper class matches ``schema``.

        Raises
        ------
        ValueError
            If neither ``tp`` nor ``schema`` is provided.
        """
        target_tp: Any
        current_schema: dict[str, Any]
        if isinstance(dct, SchemaBase):
            return dct
        elif tp is not None:
            current_schema = tp._schema
            root_schema: dict[str, Any] = rootschema or tp._rootschema or current_schema
            target_tp = tp
        elif schema is not None:
            # If there are multiple matches, we use the first one in the dict.
            # Our class dict is constructed breadth-first from top to bottom,
            # so the first class that matches is the most general match.
            current_schema = schema
            root_schema = rootschema or current_schema
            matches = self.class_dict[self.hash_schema(current_schema)]
            target_tp = matches[0] if matches else default_class
        else:
            msg = "Must provide either `tp` or `schema`, but not both."
            raise ValueError(msg)

        from_dict = partial(self.from_dict, rootschema=root_schema)
        # Can also return a list?
        resolved = _resolve_references(current_schema, root_schema)
        if "anyOf" in resolved or "oneOf" in resolved:
            schemas = resolved.get("anyOf", []) + resolved.get("oneOf", [])
            # The first alternative that validates is used for conversion.
            for possible in schemas:
                try:
                    validate_jsonschema(dct, possible, rootschema=root_schema)
                except jsonschema.ValidationError:
                    continue
                else:
                    return from_dict(dct, schema=possible, default_class=target_tp)

        if _is_dict(dct):
            # TODO: handle schemas for additionalProperties/patternProperties
            props: dict[str, Any] = resolved.get("properties", {})
            kwds = {
                k: (from_dict(v, schema=props[k]) if k in props else v)
                for k, v in dct.items()
            }
            return target_tp(**kwds)
        elif _is_list(dct):
            item_schema: dict[str, Any] = resolved.get("items", {})
            return target_tp([from_dict(k, schema=item_schema) for k in dct])
        else:
            # NOTE: Unsure what is valid here
            return target_tp(dct)
class _PropertySetter: | |
def __init__(self, prop: str, schema: dict[str, Any]) -> None: | |
self.prop = prop | |
self.schema = schema | |
def __get__(self, obj, cls): | |
self.obj = obj | |
self.cls = cls | |
# The docs from the encoding class parameter (e.g. `bin` in X, Color, | |
# etc); this provides a general description of the parameter. | |
self.__doc__ = self.schema["description"].replace("__", "**") | |
property_name = f"{self.prop}"[0].upper() + f"{self.prop}"[1:] | |
if hasattr(vegalite, property_name): | |
altair_prop = getattr(vegalite, property_name) | |
# Add the docstring from the helper class (e.g. `BinParams`) so | |
# that all the parameter names of the helper class are included in | |
# the final docstring | |
parameter_index = altair_prop.__doc__.find("Parameters\n") | |
if parameter_index > -1: | |
self.__doc__ = ( | |
altair_prop.__doc__[:parameter_index].replace(" ", "") | |
+ self.__doc__ | |
+ textwrap.dedent( | |
f"\n\n {altair_prop.__doc__[parameter_index:]}" | |
) | |
) | |
# For short docstrings such as Aggregate, Stack, et | |
else: | |
self.__doc__ = ( | |
altair_prop.__doc__.replace(" ", "") + "\n" + self.__doc__ | |
) | |
# Add signatures and tab completion for the method and parameter names | |
self.__signature__ = inspect.signature(altair_prop) | |
self.__wrapped__ = inspect.getfullargspec(altair_prop) | |
self.__name__ = altair_prop.__name__ | |
else: | |
# It seems like bandPosition is the only parameter that doesn't | |
# have a helper class. | |
pass | |
return self | |
def __call__(self, *args: Any, **kwargs: Any): | |
obj = self.obj.copy() | |
# TODO: use schema to validate | |
obj[self.prop] = args[0] if args else kwargs | |
return obj | |
def with_property_setters(cls: type[TSchemaBase]) -> type[TSchemaBase]:
    """
    Decorator to add property setters to a Schema class.

    One ``_PropertySetter`` is attached to ``cls`` for every entry under
    ``"properties"`` of its resolved schema. The class itself is returned.
    """
    properties = cls.resolve_references().get("properties", {})
    for name, property_schema in properties.items():
        setter = _PropertySetter(name, property_schema)
        setattr(cls, name, setter)
    return cls