from functools import lru_cache
from typing import Any, Dict, List, Optional, Union
from datasets import DatasetDict
from .artifact import fetch_artifact
from .dataset_utils import get_dataset_artifact
from .logging_utils import get_logger
from .metric_utils import _compute, _inference_post_process
from .operator import SourceOperator
from .schema import UNITXT_DATASET_SCHEMA
from .standard import StandardRecipe

logger = get_logger()


def load(source: Union[SourceOperator, str]) -> DatasetDict:
    """Loads a dataset, either from a SourceOperator or from the name of an artifact in the local catalog."""
    assert isinstance(
        source, (SourceOperator, str)
    ), "source must be a SourceOperator or a string"
    if isinstance(source, str):
        source, _ = fetch_artifact(source)
    return source().to_dataset()
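
# A minimal usage sketch (the artifact name below is an assumption; any
# SourceOperator artifact registered in the local catalog works the same way):
#
#     dataset = load("recipes.wnli_3_shot")
#     print(dataset["train"][0])
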
def _get_recipe_from_query(dataset_query: str) -> StandardRecipe:
    # "sys_prompt" is accepted in queries as an alias for "instruction".
    dataset_query = dataset_query.replace("sys_prompt", "instruction")
    try:
        dataset_stream, _ = fetch_artifact(dataset_query)
    except Exception:
        dataset_stream = get_dataset_artifact(dataset_query)
    return dataset_stream
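
# For illustration, a dataset query names a card plus optional recipe fields,
# e.g. (assuming these artifacts exist in the local catalog):
#
#     "card=cards.wnli,template=templates.classification.multi_class.relation.default"
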
def _get_recipe_from_dict(dataset_params: Dict[str, Any]) -> StandardRecipe:
    # Validate that every supplied keyword matches a StandardRecipe field
    # before constructing the recipe.
    recipe_attributes = list(StandardRecipe.__dict__["__fields__"].keys())
    for param in dataset_params.keys():
        assert param in recipe_attributes, (
            f"The parameter '{param}' is not an attribute of the 'StandardRecipe' class. "
            f"Please check if the name is correct. The available attributes are: '{recipe_attributes}'."
        )
    return StandardRecipe(**dataset_params)
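
# For illustration, the dict keys must be StandardRecipe fields (the catalog
# names used here are assumptions):
#
#     _get_recipe_from_dict(
#         {"card": "cards.wnli", "template": "templates.classification.multi_class.relation.default"}
#     )
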
def _verify_dataset_args(dataset_query: Optional[str] = None, dataset_args=None):
    if dataset_query and dataset_args:
        raise ValueError(
            "Cannot provide 'dataset_query' and keyword arguments at the same time. "
            "If you want to load a dataset from a card in the local catalog, use the query only. "
            "Otherwise, use keyword arguments only to specify the properties of the dataset."
        )
    if dataset_query and not isinstance(dataset_query, str):
        raise ValueError(
            f"If specified, 'dataset_query' must be a string, however, "
            f"'{dataset_query}' was provided instead, which is of type "
            f"'{type(dataset_query)}'."
        )
    if not dataset_query and not dataset_args:
        raise ValueError(
            "Either 'dataset_query' or keyword arguments must be provided."
        )


def load_recipe(
    dataset_query: Optional[Union[str, StandardRecipe]] = None, **kwargs
) -> StandardRecipe:
    # A ready-made recipe passes through unchanged.
    if isinstance(dataset_query, StandardRecipe):
        return dataset_query
    _verify_dataset_args(dataset_query, kwargs)
    if dataset_query:
        return _get_recipe_from_query(dataset_query)
    return _get_recipe_from_dict(kwargs)
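
# A sketch of both call styles (the catalog names are assumptions; substitute
# any card/template available in your catalog):
#
#     recipe = load_recipe("card=cards.wnli,template=templates.classification.multi_class.relation.default")
#     recipe = load_recipe(card="cards.wnli", template="templates.classification.multi_class.relation.default", max_train_instances=5)
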
def load_dataset(dataset_query: Optional[str] = None, **kwargs) -> DatasetDict:
    """Loads a dataset.

    If the 'dataset_query' argument is provided, the dataset is loaded from a card
    in the local catalog, based on the parameters specified in the query.
    Alternatively, the dataset is loaded from a provided card, based on explicitly
    given parameters.

    Args:
        dataset_query (str, optional): A string query specifying the dataset to load
            from the local catalog, or the name of a specific recipe or benchmark in
            the catalog. For example:
            "card=cards.wnli,template=templates.classification.multi_class.relation.default".
        **kwargs: Keyword arguments used to load a dataset from a provided card that
            is not present in the local catalog.

    Returns:
        DatasetDict

    Examples:
        dataset = load_dataset(
            dataset_query="card=cards.stsb,template=templates.regression.two_texts.simple,max_train_instances=5"
        )  # card must be present in local catalog

        card = TaskCard(...)
        template = Template(...)
        loader_limit = 10
        dataset = load_dataset(card=card, template=template, loader_limit=loader_limit)
    """
    recipe = load_recipe(dataset_query, **kwargs)
    return recipe().to_dataset(features=UNITXT_DATASET_SCHEMA)


def evaluate(predictions, data) -> List[Dict[str, Any]]:
    """Computes metric scores for the given predictions against the gold data."""
    return _compute(predictions=predictions, references=data)


def post_process(predictions, data) -> List[Dict[str, Any]]:
    """Applies the recipe's post-processing steps to raw model predictions."""
    return _inference_post_process(predictions=predictions, references=data)
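
# A typical scoring flow, sketched (assumes `predictions` aligns one-to-one
# with the instances in dataset["test"], where `dataset` was built by
# load_dataset):
#
#     results = evaluate(predictions, dataset["test"])
#     print(results[0]["score"]["global"])
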
@lru_cache
def _get_produce_with_cache(dataset_query: Optional[str] = None, **kwargs):
    # Cache the prepared recipe per (dataset_query, kwargs), so repeated calls
    # to produce() with the same arguments do not rebuild it. This requires all
    # keyword argument values to be hashable.
    return load_recipe(dataset_query, **kwargs).produce


def produce(instance_or_instances, dataset_query: Optional[str] = None, **kwargs):
    # Accept either a single instance or a list of instances; normalize to a
    # list for processing, and unwrap the result if a single instance was given.
    is_list = isinstance(instance_or_instances, list)
    if not is_list:
        instance_or_instances = [instance_or_instances]
    result = _get_produce_with_cache(dataset_query, **kwargs)(instance_or_instances)
    if not is_list:
        result = result[0]
    return result
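
# Sketch: render a single raw instance into a model-ready example (the task
# fields shown are assumptions and depend on the chosen card):
#
#     example = produce(
#         {"text_a": "...", "text_b": "...", "classes": ["entailment", "not entailment"]},
#         "card=cards.wnli,template=templates.classification.multi_class.relation.default",
#     )
#     print(example["source"])  # the fully rendered prompt
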
def infer(
    instance_or_instances,
    engine,
    dataset_query: Optional[str] = None,
    return_data=False,
    **kwargs,
):
    dataset = produce(instance_or_instances, dataset_query, **kwargs)
    # The engine may be passed as a catalog artifact name; resolve it to an object.
    engine, _ = fetch_artifact(engine)
    raw_predictions = engine.infer(dataset)
    predictions = post_process(raw_predictions, dataset)
    if return_data:
        # Attach both the raw and the post-processed prediction to each instance
        # and return the enriched dataset instead of the bare predictions.
        for prediction, raw_prediction, instance in zip(
            predictions, raw_predictions, dataset
        ):
            instance["prediction"] = prediction
            instance["raw_prediction"] = raw_prediction
        return dataset
    return predictions
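
# End-to-end sketch (the engine name is an assumption; any catalog inference
# engine artifact, or an engine object, can be passed):
#
#     predictions = infer(
#         [{"text_a": "...", "text_b": "...", "classes": ["entailment", "not entailment"]}],
#         engine="engines.model.flan.t5_small.hf",
#         dataset_query="card=cards.wnli,template=templates.classification.multi_class.relation.default",
#     )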