# Utilities for evaluating the data transforms of Altair charts via VegaFusion.
from typing import List, Optional, Tuple, Dict, Iterable, overload, Union

from altair import (
    Chart,
    FacetChart,
    LayerChart,
    HConcatChart,
    VConcatChart,
    ConcatChart,
    TopLevelUnitSpec,
    FacetedUnitSpec,
    UnitSpec,
    UnitSpecWithFrame,
    NonNormalizedSpec,
    TopLevelLayerSpec,
    LayerSpec,
    TopLevelConcatSpec,
    ConcatSpecGenericSpec,
    TopLevelHConcatSpec,
    HConcatSpecGenericSpec,
    TopLevelVConcatSpec,
    VConcatSpecGenericSpec,
    TopLevelFacetSpec,
    FacetSpec,
    data_transformers,
)
from altair.utils._vegafusion_data import get_inline_tables, import_vegafusion
from altair.utils.core import DataFrameLike
from altair.utils.schemapi import Undefined
# A scope is the path of group-mark indices that leads from the top-level
# Vega specification down to a nested group mark. The empty tuple refers to
# the top-level specification itself.
Scope = Tuple[int, ...]

# Maps a scoped facet dataset, (facet_name, facet_scope), to the scoped
# dataset it is derived from, (dataset_name, dataset_scope).
FacetMapping = Dict[Tuple[str, Scope], Tuple[str, Scope]]


# For the transformed_data functionality, the chart classes in the values
# can be considered equivalent to the chart class in the key.
_chart_class_mapping = {
    Chart: (
        Chart,
        TopLevelUnitSpec,
        FacetedUnitSpec,
        UnitSpec,
        UnitSpecWithFrame,
        NonNormalizedSpec,
    ),
    LayerChart: (LayerChart, TopLevelLayerSpec, LayerSpec),
    ConcatChart: (ConcatChart, TopLevelConcatSpec, ConcatSpecGenericSpec),
    HConcatChart: (HConcatChart, TopLevelHConcatSpec, HConcatSpecGenericSpec),
    VConcatChart: (VConcatChart, TopLevelVConcatSpec, VConcatSpecGenericSpec),
    FacetChart: (FacetChart, TopLevelFacetSpec, FacetSpec),
}
@overload
def transformed_data(
    chart: Union[Chart, FacetChart],
    row_limit: Optional[int] = None,
    exclude: Optional[Iterable[str]] = None,
) -> Optional[DataFrameLike]:
    ...


@overload
def transformed_data(
    chart: Union[LayerChart, HConcatChart, VConcatChart, ConcatChart],
    row_limit: Optional[int] = None,
    exclude: Optional[Iterable[str]] = None,
) -> List[DataFrameLike]:
    ...


def transformed_data(chart, row_limit=None, exclude=None):
    """Evaluate a Chart's transforms

    Evaluate the data transforms associated with a Chart and return the
    transformed data as one or more DataFrames

    Parameters
    ----------
    chart : Chart, FacetChart, LayerChart, HConcatChart, VConcatChart, or ConcatChart
        Altair chart to evaluate transforms on
    row_limit : int (optional)
        Maximum number of rows to return for each DataFrame. None (default) for unlimited
    exclude : iterable of str
        Set of the names of charts to exclude

    Returns
    -------
    DataFrame or list of DataFrames or None
        If input chart is a Chart or Facet Chart, returns a DataFrame of the
        transformed data (or None if that chart was excluded). Otherwise,
        returns a list of DataFrames of the transformed data

    Raises
    ------
    ValueError
        If the Vega dataset backing one of the named views cannot be located
        in the compiled Vega specification.
    """
    vf = import_vegafusion()

    # Add mark if none is specified to satisfy Vega-Lite
    if isinstance(chart, Chart) and chart.mark == Undefined:
        chart = chart.mark_point()

    # Deep copy chart so that we can rename marks without affecting caller
    chart = chart.copy(deep=True)

    # Ensure that all views are named so that we can look them up in the
    # resulting Vega specification
    chart_names = name_views(chart, 0, exclude=exclude)

    # Compile to Vega and extract inline DataFrames
    with data_transformers.enable("vegafusion"):
        vega_spec = chart.to_dict(format="vega", context={"pre_transform": False})
        inline_datasets = get_inline_tables(vega_spec)

    # Build mapping from mark names to vega datasets
    facet_mapping = get_facet_mapping(vega_spec)
    dataset_mapping = get_datasets_for_view_names(vega_spec, chart_names, facet_mapping)

    # Build a list of vega dataset names that corresponds to the order
    # of the chart components
    dataset_names = []
    for chart_name in chart_names:
        if chart_name not in dataset_mapping:
            raise ValueError("Failed to locate all datasets")
        dataset_names.append(dataset_mapping[chart_name])

    # Extract transformed datasets with VegaFusion. The second element of the
    # returned pair is not used here.
    datasets, _warnings = vf.runtime.pre_transform_datasets(
        vega_spec,
        dataset_names,
        row_limit=row_limit,
        inline_datasets=inline_datasets,
    )

    if isinstance(chart, (Chart, FacetChart)):
        # Return DataFrame (or None if it was excluded) if input was a simple Chart
        return datasets[0] if datasets else None

    # Otherwise return the list of DataFrames
    return datasets
# The equivalent classes from _chart_class_mapping should also be added
# to the type hints below for `chart` as the function would also work for them.
# However, this was not possible so far as mypy then complains about
# "Overloaded function signatures 1 and 2 overlap with incompatible return types [misc]"
# This might be due to the complex type hierarchy of the chart classes.
# See also https://github.com/python/mypy/issues/5119
# and https://github.com/python/mypy/issues/4020 which show that mypy might not have
# a very consistent behavior for overloaded functions.
# The same error appeared when trying it with Protocols for the concat and layer charts.
# This function is only used internally and so we accept this inconsistency for now.
def name_views(
    chart: Union[
        Chart, FacetChart, LayerChart, HConcatChart, VConcatChart, ConcatChart
    ],
    i: int = 0,
    exclude: Optional[Iterable[str]] = None,
) -> List[str]:
    """Name unnamed chart views

    Name unnamed charts views so that we can look them up later in
    the compiled Vega spec.

    Note: This function mutates the input chart by applying names to
    unnamed views.

    Parameters
    ----------
    chart : Chart, FacetChart, LayerChart, HConcatChart, VConcatChart, or ConcatChart
        Altair chart to apply names to
    i : int (default 0)
        Starting chart index. It is only forwarded to the recursive calls;
        generated names come from ``Chart._get_name()``.
    exclude : iterable of str
        Names of charts to exclude. A single string is treated as one name
        rather than as an iterable of characters.

    Returns
    -------
    list of str
        List of the names of the charts and subcharts

    Raises
    ------
    ValueError
        If ``chart`` is not an instance of one of the supported chart classes.
    """
    if exclude is None:
        excluded = set()
    elif isinstance(exclude, str):
        # A bare string is itself an iterable of str; without this guard
        # set("view_1") would silently become a set of single characters.
        excluded = {exclude}
    else:
        excluded = set(exclude)

    # Unit charts and facet charts are leaves: they map to exactly one name.
    leaf_classes = _chart_class_mapping[Chart] + _chart_class_mapping[FacetChart]
    if isinstance(chart, leaf_classes):
        if chart.name in excluded:
            return []
        if chart.name in (None, Undefined):
            # Add name since none is specified
            chart.name = Chart._get_name()
        return [chart.name]

    # Composite charts: collect the names of the subcharts, in order.
    if isinstance(chart, _chart_class_mapping[LayerChart]):
        subcharts = chart.layer
    elif isinstance(chart, _chart_class_mapping[HConcatChart]):
        subcharts = chart.hconcat
    elif isinstance(chart, _chart_class_mapping[VConcatChart]):
        subcharts = chart.vconcat
    elif isinstance(chart, _chart_class_mapping[ConcatChart]):
        subcharts = chart.concat
    else:
        raise ValueError(
            "transformed_data accepts an instance of "
            "Chart, FacetChart, LayerChart, HConcatChart, VConcatChart, or ConcatChart\n"
            f"Received value of type: {type(chart)}"
        )

    chart_names: List[str] = []
    for subchart in subcharts:
        chart_names.extend(
            name_views(subchart, i=i + len(chart_names), exclude=excluded)
        )
    return chart_names
def get_group_mark_for_scope(vega_spec: dict, scope: Scope) -> Optional[dict]:
    """Get the group mark at a particular scope

    Parameters
    ----------
    vega_spec : dict
        Top-level Vega specification dictionary
    scope : tuple of int
        Scope tuple. If empty, the original Vega specification is returned.
        Otherwise, the nested group mark at the scope specified is returned.
        Each entry is an index among the marks of type "group" only; marks
        of other types are not counted.

    Returns
    -------
    dict or None
        Top-level Vega spec (if scope is empty)
        or group mark (if scope is non-empty)
        or None (if group mark at scope does not exist)

    Examples
    --------
    >>> spec = {
    ...     "marks": [
    ...         {
    ...             "type": "group",
    ...             "marks": [{"type": "symbol"}]
    ...         },
    ...         {
    ...             "type": "group",
    ...             "marks": [{"type": "rect"}]}
    ...     ]
    ... }
    >>> get_group_mark_for_scope(spec, (1,))
    {'type': 'group', 'marks': [{'type': 'rect'}]}
    """
    group = vega_spec

    # Descend one level per scope entry
    for scope_value in scope:
        child_groups = [
            mark for mark in group.get("marks", []) if mark.get("type") == "group"
        ]
        # Negative or too-large indices do not identify a group mark
        if not 0 <= scope_value < len(child_groups):
            return None
        group = child_groups[scope_value]

    return group
def get_datasets_for_scope(vega_spec: dict, scope: Scope) -> List[str]:
    """Get the names of the datasets that are defined at a given scope

    Parameters
    ----------
    vega_spec : dict
        Top-level Vega specification
    scope : tuple of int
        Scope tuple. If empty, the names of top-level datasets are returned
        Otherwise, the names of the datasets defined in the nested group mark
        at the specified scope are returned.

    Returns
    -------
    list of str
        List of the names of the datasets defined at the specified scope

    Examples
    --------
    >>> spec = {
    ...     "data": [
    ...         {"name": "data1"}
    ...     ],
    ...     "marks": [
    ...         {
    ...             "type": "group",
    ...             "data": [
    ...                 {"name": "data2"}
    ...             ],
    ...             "marks": [{"type": "symbol"}]
    ...         },
    ...         {
    ...             "type": "group",
    ...             "data": [
    ...                 {"name": "data3"},
    ...                 {"name": "data4"},
    ...             ],
    ...             "marks": [{"type": "rect"}]
    ...         }
    ...     ]
    ... }

    >>> get_datasets_for_scope(spec, ())
    ['data1']

    >>> get_datasets_for_scope(spec, (0,))
    ['data2']

    >>> get_datasets_for_scope(spec, (1,))
    ['data3', 'data4']

    Returns empty when no group mark exists at scope

    >>> get_datasets_for_scope(spec, (1, 3))
    []
    """
    # A missing group is treated like an empty one so the lookups below
    # simply find nothing.
    group = get_group_mark_for_scope(vega_spec, scope) or {}

    # Datasets declared directly in the group's "data" array
    datasets = [dataset["name"] for dataset in group.get("data", [])]

    # A faceted group also exposes the facet dataset named in from.facet.name
    facet_dataset = group.get("from", {}).get("facet", {}).get("name", None)
    if facet_dataset:
        datasets.append(facet_dataset)

    return datasets
def get_definition_scope_for_data_reference(
    vega_spec: dict, data_name: str, usage_scope: Scope
) -> Optional[Scope]:
    """Return the scope that a dataset is defined at, for a given usage scope

    The usage scope itself is checked first, then each enclosing scope in
    turn, ending with the top level (the empty scope).

    Parameters
    ----------
    vega_spec: dict
        Top-level Vega specification
    data_name: str
        The name of a dataset reference
    usage_scope: tuple of int
        The scope that the dataset is referenced in

    Returns
    -------
    tuple of int
        The scope where the referenced dataset is defined,
        or None if no such dataset is found

    Examples
    --------
    >>> spec = {
    ...     "data": [
    ...         {"name": "data1"}
    ...     ],
    ...     "marks": [
    ...         {
    ...             "type": "group",
    ...             "data": [
    ...                 {"name": "data2"}
    ...             ],
    ...             "marks": [{
    ...                 "type": "symbol",
    ...                 "encode": {
    ...                     "update": {
    ...                         "x": {"field": "x", "data": "data1"},
    ...                         "y": {"field": "y", "data": "data2"},
    ...                     }
    ...                 }
    ...             }]
    ...         }
    ...     ]
    ... }

    data1 is referenced at scope [0] and defined at scope []

    >>> get_definition_scope_for_data_reference(spec, "data1", (0,))
    ()

    data2 is referenced at scope [0] and defined at scope [0]

    >>> get_definition_scope_for_data_reference(spec, "data2", (0,))
    (0,)

    data2 is not visible at scope [] (the top level),
    because it's defined in scope [0]

    >>> repr(get_definition_scope_for_data_reference(spec, "data2", ()))
    'None'
    """
    # Walk from the innermost scope outwards by shortening the prefix
    for depth in range(len(usage_scope), -1, -1):
        candidate_scope = usage_scope[:depth]
        if data_name in get_datasets_for_scope(vega_spec, candidate_scope):
            return candidate_scope
    return None
def get_facet_mapping(group: dict, scope: Scope = ()) -> FacetMapping:
    """Create mapping from facet definitions to source datasets

    Parameters
    ----------
    group : dict
        Top-level Vega spec or nested group mark
    scope : tuple of int
        Scope of the group dictionary within a top-level Vega spec

    Returns
    -------
    dict
        Dictionary from (facet_name, facet_scope) to (dataset_name, dataset_scope)

    Examples
    --------
    >>> spec = {
    ...     "data": [
    ...         {"name": "data1"}
    ...     ],
    ...     "marks": [
    ...         {
    ...             "type": "group",
    ...             "from": {
    ...                 "facet": {
    ...                     "name": "facet1",
    ...                     "data": "data1",
    ...                     "groupby": ["colA"]
    ...                 }
    ...             }
    ...         }
    ...     ]
    ... }
    >>> get_facet_mapping(spec)
    {('facet1', (0,)): ('data1', ())}
    """
    facet_mapping: FacetMapping = {}
    group_index = 0
    mark_group = get_group_mark_for_scope(group, scope) or {}

    for mark in mark_group.get("marks", []):
        # Only group marks open a new scope; other marks are not counted
        if mark.get("type", None) != "group":
            continue

        group_scope = scope + (group_index,)
        group_index += 1

        # Get facet for this group
        facet = mark.get("from", {}).get("facet", None)
        if facet is not None:
            facet_name = facet.get("name", None)
            facet_data = facet.get("data", None)
            if facet_name is not None and facet_data is not None:
                # The facet's source dataset is referenced from the enclosing
                # scope, so resolve it starting from `scope`, not `group_scope`.
                definition_scope = get_definition_scope_for_data_reference(
                    group, facet_data, scope
                )
                if definition_scope is not None:
                    facet_mapping[(facet_name, group_scope)] = (
                        facet_data,
                        definition_scope,
                    )

        # Handle children recursively
        facet_mapping.update(get_facet_mapping(group, scope=group_scope))

    return facet_mapping
def get_from_facet_mapping(
    scoped_dataset: Tuple[str, Scope], facet_mapping: FacetMapping
) -> Tuple[str, Scope]:
    """Apply facet mapping to a scoped dataset

    Parameters
    ----------
    scoped_dataset : (str, tuple of int)
        A dataset name and scope tuple
    facet_mapping : dict from (str, tuple of int) to (str, tuple of int)
        The facet mapping produced by get_facet_mapping

    Returns
    -------
    (str, tuple of int)
        Dataset name and scope tuple that has been mapped as many times as possible

    Raises
    ------
    ValueError
        If following the mapping leads back to an entry that was already
        visited, which would otherwise loop forever.

    Examples
    --------
    Facet mapping as produced by get_facet_mapping

    >>> facet_mapping = {("facet1", (0,)): ("data1", ()), ("facet2", (0, 1)): ("facet1", (0,))}
    >>> get_from_facet_mapping(("facet2", (0, 1)), facet_mapping)
    ('data1', ())
    """
    visited = set()
    while scoped_dataset in facet_mapping:
        if scoped_dataset in visited:
            raise ValueError(
                f"Cyclic facet mapping detected at dataset {scoped_dataset!r}"
            )
        visited.add(scoped_dataset)
        scoped_dataset = facet_mapping[scoped_dataset]
    return scoped_dataset
def get_datasets_for_view_names(
    group: dict,
    vl_chart_names: List[str],
    facet_mapping: FacetMapping,
    scope: Scope = (),
) -> Dict[str, Tuple[str, Scope]]:
    """Get the Vega datasets that correspond to the provided Altair view names

    Parameters
    ----------
    group : dict
        Top-level Vega spec or nested group mark
    vl_chart_names : list of str
        List of the Vega-Lite view names to find datasets for
    facet_mapping : dict from (str, tuple of int) to (str, tuple of int)
        The facet mapping produced by get_facet_mapping
    scope : tuple of int
        Scope of the group dictionary within a top-level Vega spec

    Returns
    -------
    dict from str to (str, tuple of int)
        Dict from Altair view names to scoped datasets. Views for which no
        dataset could be found are absent from the result.
    """
    datasets: Dict[str, Tuple[str, Scope]] = {}
    group_index = 0
    mark_group = get_group_mark_for_scope(group, scope) or {}

    for mark in mark_group.get("marks", []):
        name = mark.get("name", "")

        # A mark named "<view>_cell" carries the view's data in from.facet.data
        for vl_chart_name in vl_chart_names:
            if name == f"{vl_chart_name}_cell":
                # Tolerate a cell mark without a facet instead of failing
                # with AttributeError on None.
                facet = mark.get("from", {}).get("facet") or {}
                data_name = facet.get("data", None)
                if data_name is not None:
                    datasets[vl_chart_name] = get_from_facet_mapping(
                        (data_name, scope), facet_mapping
                    )
                break

        if mark.get("type", "") == "group":
            # Recurse into the nested group; entries found at this level win
            group_data_names = get_datasets_for_view_names(
                group, vl_chart_names, facet_mapping, scope=scope + (group_index,)
            )
            for view_name, scoped_dataset in group_data_names.items():
                datasets.setdefault(view_name, scoped_dataset)
            group_index += 1
        else:
            for vl_chart_name in vl_chart_names:
                # Require the "_" separator after the view name so that e.g.
                # "view_1" does not claim the marks of "view_10".
                if name.startswith(f"{vl_chart_name}_") and name.endswith("_marks"):
                    data_name = mark.get("from", {}).get("data", None)
                    definition_scope = get_definition_scope_for_data_reference(
                        group, data_name, scope
                    )
                    if definition_scope is not None:
                        datasets[vl_chart_name] = get_from_facet_mapping(
                            (data_name, definition_scope), facet_mapping
                        )
                    break

    return datasets