| """ | |
| Utility routines | |
| """ | |
| from collections.abc import Mapping, MutableMapping | |
| from copy import deepcopy | |
| import json | |
| import itertools | |
| import re | |
| import sys | |
| import traceback | |
| import warnings | |
| from typing import ( | |
| Callable, | |
| TypeVar, | |
| Any, | |
| Union, | |
| Dict, | |
| Optional, | |
| Tuple, | |
| Sequence, | |
| Type, | |
| cast, | |
| ) | |
| from types import ModuleType | |
| import jsonschema | |
| import pandas as pd | |
| import numpy as np | |
| from pandas.api.types import infer_dtype | |
| from altair.utils.schemapi import SchemaBase | |
| from altair.utils._dfi_types import Column, DtypeKind, DataFrame as DfiDataFrame | |
| if sys.version_info >= (3, 10): | |
| from typing import ParamSpec | |
| else: | |
| from typing_extensions import ParamSpec | |
| from typing import Literal, Protocol, TYPE_CHECKING | |
| if TYPE_CHECKING: | |
| from pandas.core.interchange.dataframe_protocol import Column as PandasColumn | |
| V = TypeVar("V") | |
| P = ParamSpec("P") | |
| class DataFrameLike(Protocol): | |
| def __dataframe__(self, *args, **kwargs) -> DfiDataFrame: | |
| ... | |


TYPECODE_MAP = {
    "ordinal": "O",
    "nominal": "N",
    "quantitative": "Q",
    "temporal": "T",
    "geojson": "G",
}

INV_TYPECODE_MAP = {v: k for k, v in TYPECODE_MAP.items()}


# aggregates from vega-lite version 4.6.0
AGGREGATES = [
    "argmax",
    "argmin",
    "average",
    "count",
    "distinct",
    "max",
    "mean",
    "median",
    "min",
    "missing",
    "product",
    "q1",
    "q3",
    "ci0",
    "ci1",
    "stderr",
    "stdev",
    "stdevp",
    "sum",
    "valid",
    "values",
    "variance",
    "variancep",
]

# window aggregates from vega-lite version 4.6.0
WINDOW_AGGREGATES = [
    "row_number",
    "rank",
    "dense_rank",
    "percent_rank",
    "cume_dist",
    "ntile",
    "lag",
    "lead",
    "first_value",
    "last_value",
    "nth_value",
]

# timeUnits from vega-lite version 4.17.0
TIMEUNITS = [
    "year",
    "quarter",
    "month",
    "week",
    "day",
    "dayofyear",
    "date",
    "hours",
    "minutes",
    "seconds",
    "milliseconds",
    "yearquarter",
    "yearquartermonth",
    "yearmonth",
    "yearmonthdate",
    "yearmonthdatehours",
    "yearmonthdatehoursminutes",
    "yearmonthdatehoursminutesseconds",
    "yearweek",
    "yearweekday",
    "yearweekdayhours",
    "yearweekdayhoursminutes",
    "yearweekdayhoursminutesseconds",
    "yeardayofyear",
    "quartermonth",
    "monthdate",
    "monthdatehours",
    "monthdatehoursminutes",
    "monthdatehoursminutesseconds",
    "weekday",
| "weeksdayhours", | |
| "weekdayhoursminutes", | |
| "weekdayhoursminutesseconds", | |
| "dayhours", | |
| "dayhoursminutes", | |
| "dayhoursminutesseconds", | |
| "hoursminutes", | |
| "hoursminutesseconds", | |
| "minutesseconds", | |
| "secondsmilliseconds", | |
| "utcyear", | |
| "utcquarter", | |
| "utcmonth", | |
| "utcweek", | |
| "utcday", | |
| "utcdayofyear", | |
| "utcdate", | |
| "utchours", | |
| "utcminutes", | |
| "utcseconds", | |
| "utcmilliseconds", | |
| "utcyearquarter", | |
| "utcyearquartermonth", | |
| "utcyearmonth", | |
| "utcyearmonthdate", | |
| "utcyearmonthdatehours", | |
| "utcyearmonthdatehoursminutes", | |
| "utcyearmonthdatehoursminutesseconds", | |
| "utcyearweek", | |
| "utcyearweekday", | |
| "utcyearweekdayhours", | |
| "utcyearweekdayhoursminutes", | |
| "utcyearweekdayhoursminutesseconds", | |
| "utcyeardayofyear", | |
| "utcquartermonth", | |
| "utcmonthdate", | |
| "utcmonthdatehours", | |
| "utcmonthdatehoursminutes", | |
| "utcmonthdatehoursminutesseconds", | |
| "utcweekday", | |
| "utcweeksdayhours", | |
| "utcweekdayhoursminutes", | |
| "utcweekdayhoursminutesseconds", | |
| "utcdayhours", | |
| "utcdayhoursminutes", | |
| "utcdayhoursminutesseconds", | |
| "utchoursminutes", | |
| "utchoursminutesseconds", | |
| "utcminutesseconds", | |
| "utcsecondsmilliseconds", | |
| ] | |


InferredVegaLiteType = Literal["ordinal", "nominal", "quantitative", "temporal"]


def infer_vegalite_type(
    data: object,
) -> Union[InferredVegaLiteType, Tuple[InferredVegaLiteType, list]]:
    """
    From an array-like input, infer the correct vega typecode
    ('ordinal', 'nominal', 'quantitative', or 'temporal')

    Parameters
    ----------
    data: object
    """
    typ = infer_dtype(data, skipna=False)

    if typ in [
        "floating",
        "mixed-integer-float",
        "integer",
        "mixed-integer",
        "complex",
    ]:
        return "quantitative"
    elif typ == "categorical" and hasattr(data, "cat") and data.cat.ordered:
        return ("ordinal", data.cat.categories.tolist())
    elif typ in ["string", "bytes", "categorical", "boolean", "mixed", "unicode"]:
        return "nominal"
    elif typ in [
        "datetime",
        "datetime64",
        "timedelta",
        "timedelta64",
        "date",
        "time",
        "period",
    ]:
        return "temporal"
    else:
        warnings.warn(
            "I don't know how to infer vegalite type from '{}'. "
            "Defaulting to nominal.".format(typ),
            stacklevel=1,
        )
        return "nominal"


def merge_props_geom(feat: dict) -> dict:
    """
    Merge properties with geometry
    * Overwrites 'type' and 'geometry' entries if existing
    """
    geom = {k: feat[k] for k in ("type", "geometry")}
    try:
        feat["properties"].update(geom)
        props_geom = feat["properties"]
    except (AttributeError, KeyError):
        # AttributeError when 'properties' equals None
        # KeyError when 'properties' is non-existing
        props_geom = geom

    return props_geom
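
# Illustrative usage (not part of this module): flattening a GeoJSON feature so
# its properties sit alongside 'type' and 'geometry'.
#
#     >>> feat = {
#     ...     "type": "Feature",
#     ...     "geometry": {"type": "Point", "coordinates": [0.0, 0.0]},
#     ...     "properties": {"name": "origin"},
#     ... }
#     >>> merge_props_geom(feat) == {
#     ...     "name": "origin",
#     ...     "type": "Feature",
#     ...     "geometry": {"type": "Point", "coordinates": [0.0, 0.0]},
#     ... }
#     True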


def sanitize_geo_interface(geo: MutableMapping) -> dict:
    """Sanitize a geo_interface to prepare it for serialization.

    * Make a copy
    * Convert type array or _Array to list
    * Convert tuples to lists (using json.loads/dumps)
    * Merge properties with geometry
    """
    geo = deepcopy(geo)

    # convert type _Array or array to list
    for key in geo.keys():
        if str(type(geo[key]).__name__).startswith(("_Array", "array")):
            geo[key] = geo[key].tolist()

    # convert (nested) tuples to lists
    geo_dct: dict = json.loads(json.dumps(geo))

    # sanitize features
    if geo_dct["type"] == "FeatureCollection":
        geo_dct = geo_dct["features"]
        if len(geo_dct) > 0:
            for idx, feat in enumerate(geo_dct):
                geo_dct[idx] = merge_props_geom(feat)
    elif geo_dct["type"] == "Feature":
        geo_dct = merge_props_geom(geo_dct)
    else:
        geo_dct = {"type": "Feature", "geometry": geo_dct}

    return geo_dct
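
# Illustrative usage (not part of this module): a bare geometry is wrapped in a
# Feature, and nested tuples become lists via the JSON round trip.
#
#     >>> sanitize_geo_interface({"type": "Point", "coordinates": (0.0, 0.0)})
#     {'type': 'Feature', 'geometry': {'type': 'Point', 'coordinates': [0.0, 0.0]}}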


def numpy_is_subtype(dtype: Any, subtype: Any) -> bool:
    try:
        return np.issubdtype(dtype, subtype)
    except (NotImplementedError, TypeError):
        return False
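
# Illustrative usage (not part of this module): behaves like np.issubdtype but
# returns False instead of raising on types NumPy cannot interpret, such as
# pandas extension dtypes.
#
#     >>> numpy_is_subtype(np.dtype("int64"), np.integer)
#     True
#     >>> numpy_is_subtype(pd.Int64Dtype(), np.integer)
#     False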


def sanitize_dataframe(df: pd.DataFrame) -> pd.DataFrame:  # noqa: C901
    """Sanitize a DataFrame to prepare it for serialization.

    * Make a copy
    * Convert RangeIndex columns to strings
    * Raise ValueError if column names are not strings
    * Raise ValueError if it has a hierarchical index.
    * Convert categoricals to strings.
    * Convert np.bool_ dtypes to Python bool objects
    * Convert np.int dtypes to Python int objects
    * Convert floats to objects and replace NaNs/infs with None.
    * Convert DateTime dtypes into appropriate string representations
    * Convert Nullable integers to objects and replace NaN with None
    * Convert Nullable boolean to objects and replace NaN with None
    * Convert dedicated string column to objects and replace NaN with None
    * Raise a ValueError for TimeDelta dtypes
    """
    df = df.copy()

    if isinstance(df.columns, pd.RangeIndex):
        df.columns = df.columns.astype(str)

    for col_name in df.columns:
        if not isinstance(col_name, str):
            raise ValueError(
                "Dataframe contains invalid column name: {0!r}. "
                "Column names must be strings".format(col_name)
            )

    if isinstance(df.index, pd.MultiIndex):
        raise ValueError("Hierarchical indices not supported")
    if isinstance(df.columns, pd.MultiIndex):
        raise ValueError("Hierarchical indices not supported")

    def to_list_if_array(val):
        if isinstance(val, np.ndarray):
            return val.tolist()
        else:
            return val

    for dtype_item in df.dtypes.items():
        # We know that the column names are strings from the isinstance check
        # further above but mypy thinks it is of type Hashable and therefore does not
        # let us assign it to the col_name variable which is already of type str.
        col_name = cast(str, dtype_item[0])
        dtype = dtype_item[1]
        dtype_name = str(dtype)
        if dtype_name == "category":
            # Work around bug in to_json for categorical types in older versions
            # of pandas as they do not properly convert NaN values to null in to_json.
            # We can probably remove this part once we require pandas >= 1.0
            col = df[col_name].astype(object)
            df[col_name] = col.where(col.notnull(), None)
        elif dtype_name == "string":
            # dedicated string datatype (since 1.0)
            # https://pandas.pydata.org/pandas-docs/version/1.0.0/whatsnew/v1.0.0.html#dedicated-string-data-type
            col = df[col_name].astype(object)
            df[col_name] = col.where(col.notnull(), None)
        elif dtype_name == "bool":
            # convert numpy bools to objects; np.bool is not JSON serializable
            df[col_name] = df[col_name].astype(object)
        elif dtype_name == "boolean":
            # dedicated boolean datatype (since 1.0)
            # https://pandas.io/docs/user_guide/boolean.html
            col = df[col_name].astype(object)
            df[col_name] = col.where(col.notnull(), None)
        elif dtype_name.startswith("datetime") or dtype_name.startswith("timestamp"):
            # Convert datetimes to strings. This needs to be a full ISO string
            # with time, which is why we cannot use ``col.astype(str)``.
            # This is because Javascript parses date-only times in UTC, but
            # parses full ISO-8601 dates as local time, and dates in Vega and
            # Vega-Lite are displayed in local time by default.
            # (see https://github.com/altair-viz/altair/issues/1027)
            df[col_name] = (
                df[col_name].apply(lambda x: x.isoformat()).replace("NaT", "")
            )
        elif dtype_name.startswith("timedelta"):
            raise ValueError(
                'Field "{col_name}" has type "{dtype}" which is '
                "not supported by Altair. Please convert to "
                "either a timestamp or a numerical value."
                "".format(col_name=col_name, dtype=dtype)
            )
        elif dtype_name.startswith("geometry"):
            # geopandas >=0.6.1 uses the dtype geometry. Continue here
            # otherwise it will give an error on np.issubdtype(dtype, np.integer)
            continue
        elif (
            dtype_name
            in {
                "Int8",
                "Int16",
                "Int32",
                "Int64",
                "UInt8",
                "UInt16",
                "UInt32",
                "UInt64",
                "Float32",
                "Float64",
            }
        ):  # nullable integer datatypes (since 0.24.0) and nullable float datatypes (since 1.2.0)
            # https://pandas.pydata.org/pandas-docs/version/0.25/whatsnew/v0.24.0.html#optional-integer-na-support
            col = df[col_name].astype(object)
            df[col_name] = col.where(col.notnull(), None)
        elif numpy_is_subtype(dtype, np.integer):
            # convert integers to objects; np.int is not JSON serializable
            df[col_name] = df[col_name].astype(object)
        elif numpy_is_subtype(dtype, np.floating):
            # For floats, convert to Python float: np.float is not JSON serializable
            # Also convert NaN/inf values to null, as they are not JSON serializable
            col = df[col_name]
            bad_values = col.isnull() | np.isinf(col)
            df[col_name] = col.astype(object).where(~bad_values, None)
        elif dtype == object:
            # Convert numpy arrays saved as objects to lists
            # Arrays are not JSON serializable
            col = df[col_name].astype(object).apply(to_list_if_array)
            df[col_name] = col.where(col.notnull(), None)
    return df
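
# Illustrative usage (not part of this module): NaN and infinity become None so
# the frame serializes to valid JSON, and datetimes become ISO-8601 strings.
#
#     >>> out = sanitize_dataframe(pd.DataFrame({
#     ...     "x": [1.0, np.nan, np.inf],
#     ...     "t": pd.to_datetime(["2024-01-01", "2024-01-02", "2024-01-03"]),
#     ... }))
#     >>> out["x"].tolist()
#     [1.0, None, None]
#     >>> out["t"].tolist()
#     ['2024-01-01T00:00:00', '2024-01-02T00:00:00', '2024-01-03T00:00:00']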


def sanitize_arrow_table(pa_table):
    """Sanitize arrow table for JSON serialization"""
    import pyarrow as pa
    import pyarrow.compute as pc

    arrays = []
    schema = pa_table.schema
    for name in schema.names:
        array = pa_table[name]
        dtype = schema.field(name).type
        if str(dtype).startswith("timestamp"):
            arrays.append(pc.strftime(array))
        elif str(dtype).startswith("duration"):
            raise ValueError(
                'Field "{col_name}" has type "{dtype}" which is '
                "not supported by Altair. Please convert to "
                "either a timestamp or a numerical value."
                "".format(col_name=name, dtype=dtype)
            )
        else:
            arrays.append(array)

    return pa.Table.from_arrays(arrays, names=schema.names)
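
# Illustrative usage (not part of this module; assumes pyarrow is installed):
# timestamp columns are rendered as ISO-style strings, everything else passes
# through unchanged.
#
#     >>> import pyarrow as pa
#     >>> table = pa.table({
#     ...     "t": pa.array([0], type=pa.timestamp("s")),
#     ...     "x": pa.array([1]),
#     ... })
#     >>> sanitize_arrow_table(table)["t"].to_pylist()
#     ['1970-01-01T00:00:00']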


def parse_shorthand(
    shorthand: Union[Dict[str, Any], str],
    data: Optional[Union[pd.DataFrame, DataFrameLike]] = None,
    parse_aggregates: bool = True,
    parse_window_ops: bool = False,
    parse_timeunits: bool = True,
    parse_types: bool = True,
) -> Dict[str, Any]:
    """General tool to parse shorthand values

    These are of the form:

    - "col_name"
    - "col_name:O"
    - "average(col_name)"
    - "average(col_name):O"

    Optionally, a dataframe may be supplied, from which the type
    will be inferred if not specified in the shorthand.

    Parameters
    ----------
    shorthand : dict or string
        The shorthand representation to be parsed
    data : DataFrame, optional
        If specified and of type DataFrame, then use these values to infer the
        column type if not provided by the shorthand.
    parse_aggregates : boolean
        If True (default), then parse aggregate functions within the shorthand.
    parse_window_ops : boolean
        If True, then parse window operations within the shorthand (default: False).
    parse_timeunits : boolean
        If True (default), then parse timeUnits from within the shorthand.
    parse_types : boolean
        If True (default), then parse typecodes within the shorthand.

    Returns
    -------
    attrs : dict
        a dictionary of attributes extracted from the shorthand

    Examples
    --------
    >>> data = pd.DataFrame({'foo': ['A', 'B', 'A', 'B'],
    ...                      'bar': [1, 2, 3, 4]})

    >>> parse_shorthand('name') == {'field': 'name'}
    True

    >>> parse_shorthand('name:Q') == {'field': 'name', 'type': 'quantitative'}
    True

    >>> parse_shorthand('average(col)') == {'aggregate': 'average', 'field': 'col'}
    True

    >>> parse_shorthand('foo:O') == {'field': 'foo', 'type': 'ordinal'}
    True

    >>> parse_shorthand('min(foo):Q') == {'aggregate': 'min', 'field': 'foo', 'type': 'quantitative'}
    True

    >>> parse_shorthand('month(col)') == {'field': 'col', 'timeUnit': 'month', 'type': 'temporal'}
    True

    >>> parse_shorthand('year(col):O') == {'field': 'col', 'timeUnit': 'year', 'type': 'ordinal'}
    True

    >>> parse_shorthand('foo', data) == {'field': 'foo', 'type': 'nominal'}
    True

    >>> parse_shorthand('bar', data) == {'field': 'bar', 'type': 'quantitative'}
    True

    >>> parse_shorthand('bar:O', data) == {'field': 'bar', 'type': 'ordinal'}
    True

    >>> parse_shorthand('sum(bar)', data) == {'aggregate': 'sum', 'field': 'bar', 'type': 'quantitative'}
    True

    >>> parse_shorthand('count()', data) == {'aggregate': 'count', 'type': 'quantitative'}
    True
    """
    from altair.utils._importers import pyarrow_available

    if not shorthand:
        return {}

    valid_typecodes = list(TYPECODE_MAP) + list(INV_TYPECODE_MAP)

    units = {
        "field": "(?P<field>.*)",
        "type": "(?P<type>{})".format("|".join(valid_typecodes)),
        "agg_count": "(?P<aggregate>count)",
        "op_count": "(?P<op>count)",
        "aggregate": "(?P<aggregate>{})".format("|".join(AGGREGATES)),
        "window_op": "(?P<op>{})".format("|".join(AGGREGATES + WINDOW_AGGREGATES)),
        "timeUnit": "(?P<timeUnit>{})".format("|".join(TIMEUNITS)),
    }

    patterns = []

    if parse_aggregates:
        patterns.extend([r"{agg_count}\(\)"])
        patterns.extend([r"{aggregate}\({field}\)"])
    if parse_window_ops:
        patterns.extend([r"{op_count}\(\)"])
        patterns.extend([r"{window_op}\({field}\)"])
    if parse_timeunits:
        patterns.extend([r"{timeUnit}\({field}\)"])

    patterns.extend([r"{field}"])

    if parse_types:
        patterns = list(itertools.chain(*((p + ":{type}", p) for p in patterns)))

    regexps = (
        re.compile(r"\A" + p.format(**units) + r"\Z", re.DOTALL) for p in patterns
    )

    # find matches depending on valid fields passed
    if isinstance(shorthand, dict):
        attrs = shorthand
    else:
        attrs = next(
            exp.match(shorthand).groupdict()  # type: ignore[union-attr]
            for exp in regexps
            if exp.match(shorthand) is not None
        )

    # Handle short form of the type expression
    if "type" in attrs:
        attrs["type"] = INV_TYPECODE_MAP.get(attrs["type"], attrs["type"])

    # counts are quantitative by default
    if attrs == {"aggregate": "count"}:
        attrs["type"] = "quantitative"

    # times are temporal by default
    if "timeUnit" in attrs and "type" not in attrs:
        attrs["type"] = "temporal"

    # if data is specified and type is not, infer type from data
    if "type" not in attrs:
        if pyarrow_available() and data is not None and hasattr(data, "__dataframe__"):
            dfi = data.__dataframe__()
            if "field" in attrs:
                unescaped_field = attrs["field"].replace("\\", "")
                if unescaped_field in dfi.column_names():
                    column = dfi.get_column_by_name(unescaped_field)
                    try:
                        attrs["type"] = infer_vegalite_type_for_dfi_column(column)
                    except (NotImplementedError, AttributeError, ValueError):
                        # Fall back to pandas-based inference.
                        # Note: The AttributeError catch is a workaround for
                        # https://github.com/pandas-dev/pandas/issues/55332
                        if isinstance(data, pd.DataFrame):
                            attrs["type"] = infer_vegalite_type(data[unescaped_field])
                        else:
                            raise

                    if isinstance(attrs["type"], tuple):
                        attrs["sort"] = attrs["type"][1]
                        attrs["type"] = attrs["type"][0]
        elif isinstance(data, pd.DataFrame):
            # Fallback if pyarrow is not installed or if pandas is older than 1.5
            #
            # Remove escape sequences so that types can be inferred for columns with special characters
            if "field" in attrs and attrs["field"].replace("\\", "") in data.columns:
                attrs["type"] = infer_vegalite_type(
                    data[attrs["field"].replace("\\", "")]
                )
                # ordered categorical dataframe columns return the type and sort order as a tuple
                if isinstance(attrs["type"], tuple):
                    attrs["sort"] = attrs["type"][1]
                    attrs["type"] = attrs["type"][0]

    # If an unescaped colon is still present, it's often due to an incorrect data type specification
    # but could also be due to using a column name with ":" in it.
    if (
        "field" in attrs
        and ":" in attrs["field"]
        and attrs["field"][attrs["field"].rfind(":") - 1] != "\\"
    ):
        raise ValueError(
            '"{}" '.format(attrs["field"].split(":")[-1])
            + "is not one of the valid encoding data types: {}.".format(
                ", ".join(TYPECODE_MAP.values())
            )
            + "\nFor more details, see https://altair-viz.github.io/user_guide/encodings/index.html#encoding-data-types. "
            + "If you are trying to use a column name that contains a colon, "
            + 'prefix it with a backslash; for example "column\\:name" instead of "column:name".'
        )
    return attrs
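
# Illustrative usage (not part of this module): an unrecognized type code after
# an unescaped colon raises, while escaping the colon keeps it as part of the
# field name.
#
#     parse_shorthand("name:X")         # raises ValueError: '"X" is not one of
#                                       # the valid encoding data types: O, N, Q, T, G.'
#     parse_shorthand("column\\:name")  # -> {'field': 'column\\:name'}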


def infer_vegalite_type_for_dfi_column(
    column: Union[Column, "PandasColumn"],
) -> Union[InferredVegaLiteType, Tuple[InferredVegaLiteType, list]]:
    from pyarrow.interchange.from_dataframe import column_to_array

    try:
        kind = column.dtype[0]
    except NotImplementedError as e:
        # Edge case hack:
        # dtype access fails for pandas column with datetime64[ns, UTC] type,
        # but all we need to know is that it's temporal, so check the
        # error message for the presence of datetime64.
        #
        # See https://github.com/pandas-dev/pandas/issues/54239
        if "datetime64" in e.args[0] or "timestamp" in e.args[0]:
            return "temporal"
        raise e

    if (
        kind == DtypeKind.CATEGORICAL
        and column.describe_categorical["is_ordered"]
        and column.describe_categorical["categories"] is not None
    ):
        # Treat ordered categorical column as Vega-Lite ordinal
        categories_column = column.describe_categorical["categories"]
        categories_array = column_to_array(categories_column)
        return "ordinal", categories_array.to_pylist()
    if kind in (DtypeKind.STRING, DtypeKind.CATEGORICAL, DtypeKind.BOOL):
        return "nominal"
    elif kind in (DtypeKind.INT, DtypeKind.UINT, DtypeKind.FLOAT):
        return "quantitative"
    elif kind == DtypeKind.DATETIME:
        return "temporal"
    else:
        raise ValueError(f"Unexpected DtypeKind: {kind}")


def use_signature(Obj: Callable[P, Any]):
    """Apply call signature and documentation of Obj to the decorated method"""

    def decorate(f: Callable[..., V]) -> Callable[P, V]:
        # call-signature of f is exposed via __wrapped__.
        # we want it to mimic Obj.__init__
        f.__wrapped__ = Obj.__init__  # type: ignore
        f._uses_signature = Obj  # type: ignore

        # Supplement the docstring of f with information from Obj
        if Obj.__doc__:
            # Patch in a reference to the class this docstring is copied from,
            # to generate a hyperlink.
            doclines = Obj.__doc__.splitlines()
            doclines[0] = f"Refer to :class:`{Obj.__name__}`"
            if f.__doc__:
                doc = f.__doc__ + "\n".join(doclines[1:])
            else:
                doc = "\n".join(doclines)
            try:
                f.__doc__ = doc
            except AttributeError:
                # __doc__ is not modifiable for classes in Python < 3.3
                pass

        return f

    return decorate
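
# Illustrative usage (not part of this module): a wrapper function borrows the
# docstring of a class for introspection tools such as Sphinx. ``Point`` is a
# stand-in class for this sketch, not something defined by this module.
#
#     >>> class Point:
#     ...     """A point.
#     ...
#     ...     More detail here.
#     ...     """
#     ...     def __init__(self, x: float, y: float) -> None:
#     ...         self.x, self.y = x, y
#     >>> @use_signature(Point)
#     ... def make_point(*args, **kwargs):
#     ...     return Point(*args, **kwargs)
#     >>> make_point.__doc__.splitlines()[0]
#     'Refer to :class:`Point`'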


def update_nested(
    original: MutableMapping, update: Mapping, copy: bool = False
) -> MutableMapping:
    """Update nested dictionaries

    Parameters
    ----------
    original : MutableMapping
        the original (nested) dictionary, which will be updated in-place
    update : Mapping
        the nested dictionary of updates
    copy : bool, default False
        if True, then copy the original dictionary rather than modifying it

    Returns
    -------
    original : MutableMapping
        a reference to the (modified) original dict

    Examples
    --------
    >>> original = {'x': {'b': 2, 'c': 4}}
    >>> update = {'x': {'b': 5, 'd': 6}, 'y': 40}
    >>> update_nested(original, update)  # doctest: +SKIP
    {'x': {'b': 5, 'c': 4, 'd': 6}, 'y': 40}
    >>> original  # doctest: +SKIP
    {'x': {'b': 5, 'c': 4, 'd': 6}, 'y': 40}
    """
    if copy:
        original = deepcopy(original)
    for key, val in update.items():
        if isinstance(val, Mapping):
            orig_val = original.get(key, {})
            if isinstance(orig_val, MutableMapping):
                original[key] = update_nested(orig_val, val)
            else:
                original[key] = val
        else:
            original[key] = val
    return original
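
# Illustrative usage (not part of this module): with ``copy=True`` the input
# mapping is left untouched and the merged result is a fresh dictionary.
#
#     >>> original = {'x': {'b': 2}}
#     >>> merged = update_nested(original, {'x': {'c': 3}}, copy=True)
#     >>> merged == {'x': {'b': 2, 'c': 3}}
#     True
#     >>> original == {'x': {'b': 2}}
#     True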


def display_traceback(in_ipython: bool = True):
    exc_info = sys.exc_info()

    if in_ipython:
        from IPython.core.getipython import get_ipython

        ip = get_ipython()
    else:
        ip = None

    if ip is not None:
        ip.showtraceback(exc_info)
    else:
        traceback.print_exception(*exc_info)
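
# Illustrative usage (not part of this module): call from inside an ``except``
# block; outside IPython the traceback is printed via the stdlib.
#
#     try:
#         1 / 0
#     except ZeroDivisionError:
#         display_traceback(in_ipython=False)  # falls back to traceback.print_exception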


def infer_encoding_types(args: Sequence, kwargs: MutableMapping, channels: ModuleType):
    """Infer typed keyword arguments for args and kwargs

    Parameters
    ----------
    args : Sequence
        Sequence of function args
    kwargs : MutableMapping
        Dict of function kwargs
    channels : ModuleType
        The module containing all altair encoding channel classes.

    Returns
    -------
    kwargs : dict
        All args and kwargs in a single dict, with keys and types
        based on the channels mapping.
    """
    # Construct a dictionary of channel type to encoding name
    # TODO: cache this somehow?
    channel_objs = (getattr(channels, name) for name in dir(channels))
    channel_objs = (
        c for c in channel_objs if isinstance(c, type) and issubclass(c, SchemaBase)
    )
    channel_to_name: Dict[Type[SchemaBase], str] = {
        c: c._encoding_name for c in channel_objs
    }
    name_to_channel: Dict[str, Dict[str, Type[SchemaBase]]] = {}
    for chan, name in channel_to_name.items():
        chans = name_to_channel.setdefault(name, {})
        if chan.__name__.endswith("Datum"):
            key = "datum"
        elif chan.__name__.endswith("Value"):
            key = "value"
        else:
            key = "field"
        chans[key] = chan

    # First use the mapping to convert args to kwargs based on their types.
    for arg in args:
        if isinstance(arg, (list, tuple)) and len(arg) > 0:
            type_ = type(arg[0])
        else:
            type_ = type(arg)
        encoding = channel_to_name.get(type_, None)
        if encoding is None:
            raise NotImplementedError("positional of type {}".format(type_))
        if encoding in kwargs:
            raise ValueError("encoding {} specified twice.".format(encoding))
        kwargs[encoding] = arg

    def _wrap_in_channel_class(obj, encoding):
        if isinstance(obj, SchemaBase):
            return obj

        if isinstance(obj, str):
            obj = {"shorthand": obj}

        if isinstance(obj, (list, tuple)):
            return [_wrap_in_channel_class(subobj, encoding) for subobj in obj]

        if encoding not in name_to_channel:
            warnings.warn(
                "Unrecognized encoding channel '{}'".format(encoding), stacklevel=1
            )
            return obj

        classes = name_to_channel[encoding]
        cls = classes["value"] if "value" in obj else classes["field"]

        try:
            # Don't force validation here; some objects won't be valid until
            # they're created in the context of a chart.
            return cls.from_dict(obj, validate=False)
        except jsonschema.ValidationError:
            # our attempts at finding the correct class have failed
            return obj

    return {
        encoding: _wrap_in_channel_class(obj, encoding)
        for encoding, obj in kwargs.items()
    }
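
# Illustrative usage (not part of this module; assumes an installed Altair with
# its generated ``channels`` module, whose exact import path depends on the
# Vega-Lite schema version bundled with your Altair release):
#
#     >>> import altair as alt
#     >>> from altair.vegalite.v5.schema import channels  # path may vary
#     >>> enc = infer_encoding_types((alt.X("a:Q"),), {"y": "b:N"}, channels)
#     >>> sorted(enc)
#     ['x', 'y']
#     >>> type(enc["y"]).__name__
#     'Y'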