# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" | |
Helpful utility functions and classes in relation to exploring API endpoints | |
with the aim for a user-friendly interface. | |
""" | |
import math | |
import re | |
from dataclasses import dataclass | |
from typing import TYPE_CHECKING, Iterable, List, Optional, Union | |
if TYPE_CHECKING: | |
from ..hf_api import ModelInfo | |


def _filter_emissions(
    models: Iterable["ModelInfo"],
    minimum_threshold: Optional[float] = None,
    maximum_threshold: Optional[float] = None,
) -> Iterable["ModelInfo"]:
    """Filters a list of models for those that include an emissions tag and
    whose reported emissions fall between the two thresholds.

    Args:
        models (Iterable of `ModelInfo`):
            A list of models to filter.
        minimum_threshold (`float`, *optional*):
            A minimum carbon threshold to filter by, such as 1.
        maximum_threshold (`float`, *optional*):
            A maximum carbon threshold to filter by, such as 10.
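
    Example:
        A minimal usage sketch, assuming the models were fetched with card
        data included (e.g. via `HfApi.list_models(cardData=True)`) so that
        emissions metadata is present:

        ```py
        >>> from huggingface_hub import HfApi
        >>> models = HfApi().list_models(cardData=True)
        >>> low_carbon = list(_filter_emissions(models, maximum_threshold=10.0))
        ```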
""" | |
if minimum_threshold is None and maximum_threshold is None: | |
raise ValueError("Both `minimum_threshold` and `maximum_threshold` cannot both be `None`") | |
if minimum_threshold is None: | |
minimum_threshold = -1 | |
if maximum_threshold is None: | |
maximum_threshold = math.inf | |
for model in models: | |
card_data = getattr(model, "cardData", None) | |
if card_data is None or not isinstance(card_data, dict): | |
continue | |
# Get CO2 emission metadata | |
emission = card_data.get("co2_eq_emissions", None) | |
if isinstance(emission, dict): | |
emission = emission["emissions"] | |
if not emission: | |
continue | |
# Filter out if value is missing or out of range | |
matched = re.search(r"\d+\.\d+|\d+", str(emission)) | |
if matched is None: | |
continue | |
emission_value = float(matched.group(0)) | |
if emission_value >= minimum_threshold and emission_value <= maximum_threshold: | |
yield model | |


@dataclass
class DatasetFilter:
    """
    A class that converts human-readable dataset search parameters into ones
    compatible with the REST API. For all parameters, capitalization does not
    matter.

    Args:
        author (`str`, *optional*):
            A string that can be used to identify datasets on
            the Hub by the original uploader (author or organization), such as
            `facebook` or `huggingface`.
        benchmark (`str` or `List`, *optional*):
            A string or list of strings that can be used to identify datasets on
            the Hub by their official benchmark.
        dataset_name (`str`, *optional*):
            A string that can be used to identify datasets on
            the Hub by name, such as `SQAC` or `wikineural`.
        language_creators (`str` or `List`, *optional*):
            A string or list of strings that can be used to identify datasets on
            the Hub by how the data was curated, such as `crowdsourced` or
            `machine_generated`.
        language (`str` or `List`, *optional*):
            A string or list of strings representing a two-character language to
            filter datasets by on the Hub.
        multilinguality (`str` or `List`, *optional*):
            A string or list of strings representing a filter for datasets that
            contain multiple languages.
        size_categories (`str` or `List`, *optional*):
            A string or list of strings that can be used to identify datasets on
            the Hub by the size of the dataset, such as `100K<n<1M` or
            `1M<n<10M`.
        task_categories (`str` or `List`, *optional*):
            A string or list of strings that can be used to identify datasets on
            the Hub by the designed task, such as `audio_classification` or
            `named_entity_recognition`.
        task_ids (`str` or `List`, *optional*):
            A string or list of strings that can be used to identify datasets on
            the Hub by the specific task, such as `speech_emotion_recognition` or
            `paraphrase`.

    Examples:

    ```py
    >>> from huggingface_hub import DatasetFilter
    >>> # Using author
    >>> new_filter = DatasetFilter(author="facebook")
    >>> # Using benchmark
    >>> new_filter = DatasetFilter(benchmark="raft")
    >>> # Using dataset_name
    >>> new_filter = DatasetFilter(dataset_name="wikineural")
    >>> # Using language_creators
    >>> new_filter = DatasetFilter(language_creators="crowdsourced")
    >>> # Using language
    >>> new_filter = DatasetFilter(language="en")
    >>> # Using multilinguality
    >>> new_filter = DatasetFilter(multilinguality="multilingual")
    >>> # Using size_categories
    >>> new_filter = DatasetFilter(size_categories="100K<n<1M")
    >>> # Using task_categories
    >>> new_filter = DatasetFilter(task_categories="audio_classification")
    >>> # Using task_ids
    >>> new_filter = DatasetFilter(task_ids="paraphrase")
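
    >>> # Illustrative: a filter like the ones above is typically passed to
    >>> # `HfApi.list_datasets`
    >>> from huggingface_hub import HfApi
    >>> api = HfApi()
    >>> datasets = api.list_datasets(filter=new_filter)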
    ```
    """

    author: Optional[str] = None
    benchmark: Optional[Union[str, List[str]]] = None
    dataset_name: Optional[str] = None
    language_creators: Optional[Union[str, List[str]]] = None
    language: Optional[Union[str, List[str]]] = None
    multilinguality: Optional[Union[str, List[str]]] = None
    size_categories: Optional[Union[str, List[str]]] = None
    task_categories: Optional[Union[str, List[str]]] = None
    task_ids: Optional[Union[str, List[str]]] = None


@dataclass
class ModelFilter:
    """
    A class that converts human-readable model search parameters into ones
    compatible with the REST API. For all parameters, capitalization does not
    matter.

    Args:
        author (`str`, *optional*):
            A string that can be used to identify models on the Hub by the
            original uploader (author or organization), such as `facebook` or
            `huggingface`.
        library (`str` or `List`, *optional*):
            A string or list of strings of foundational libraries models were
            originally trained from, such as `pytorch`, `tensorflow`, or
            `allennlp`.
        language (`str` or `List`, *optional*):
            A string or list of strings of languages, both by name and country
            code, such as "en" or "English".
        model_name (`str`, *optional*):
            A string that contains complete or partial names for models on the
            Hub, such as "bert" or "bert-base-cased".
        task (`str` or `List`, *optional*):
            A string or list of strings of tasks models were designed for, such
            as "fill-mask" or "automatic-speech-recognition".
        tags (`str` or `List`, *optional*):
            A string tag or a list of tags to filter models on the Hub by, such
            as `text-generation` or `spacy`.
        trained_dataset (`str` or `List`, *optional*):
            A string tag or a list of string tags of the dataset a model on the
            Hub was trained on.

    Examples:

    ```python
    >>> from huggingface_hub import ModelFilter
    >>> # For the author
    >>> new_filter = ModelFilter(author="facebook")
    >>> # For the library
    >>> new_filter = ModelFilter(library="pytorch")
    >>> # For the language
    >>> new_filter = ModelFilter(language="french")
    >>> # For the model_name
    >>> new_filter = ModelFilter(model_name="bert")
    >>> # For the task
    >>> new_filter = ModelFilter(task="text-classification")
    >>> # Retrieving tags using the `HfApi.get_model_tags` method
    >>> from huggingface_hub import HfApi
    >>> api = HfApi()
    >>> # To list model tags
    >>> api.get_model_tags()
    >>> # To list dataset tags
    >>> api.get_dataset_tags()
    >>> new_filter = ModelFilter(tags="benchmark:raft")
    >>> # Related to the dataset
    >>> new_filter = ModelFilter(trained_dataset="common_voice")
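
    >>> # Illustrative: a filter like the ones above is typically passed to
    >>> # `HfApi.list_models`
    >>> models = api.list_models(filter=new_filter)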
    ```
    """

    author: Optional[str] = None
    library: Optional[Union[str, List[str]]] = None
    language: Optional[Union[str, List[str]]] = None
    model_name: Optional[str] = None
    task: Optional[Union[str, List[str]]] = None
    trained_dataset: Optional[Union[str, List[str]]] = None
    tags: Optional[Union[str, List[str]]] = None


class AttributeDictionary(dict):
    """
    `dict` subclass that also provides access to keys as attributes.

    If a key starts with a number, it will exist in the dictionary but not as an
    attribute.

    Example:

    ```python
    >>> d = AttributeDictionary()
    >>> d["test"] = "a"
    >>> print(d.test)  # prints "a"
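    >>> # A key starting with a number stays dictionary-only:
    >>> d["1_key"] = "b"
    >>> d["1_key"]  # `d.1_key` would be a SyntaxError
    'b'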
    ```
    """

    def __getattr__(self, k):
        if k in self:
            return self[k]
        else:
            raise AttributeError(k)

    def __setattr__(self, k, v):
        # Private attributes (leading underscore) are stored as regular
        # attributes; everything else is stored as a dictionary item.
        if k.startswith("_"):
            super().__setattr__(k, v)
        else:
            self[k] = v

    def __delattr__(self, k):
        if k in self:
            del self[k]
        else:
            raise AttributeError(k)

    def __dir__(self):
        keys = sorted(self.keys())
        # Only keys that are valid identifiers are exposed as attributes
        keys = [key for key in keys if key.replace("_", "").isalpha()]
        return super().__dir__() + keys

    def __repr__(self):
        repr_str = "Available Attributes or Keys:\n"
        for key in sorted(self.keys()):
            repr_str += f" * {key}"
            if not key.replace("_", "").isalpha():
                repr_str += " (Key only)"
            repr_str += "\n"
        return repr_str


class GeneralTags(AttributeDictionary):
    """
    A namespace object holding all tags, filtered by `keys`. If a tag starts with
    a number, it will only exist in the dictionary.

    Example:

    ```python
    >>> a.b["1a"]  # will work
    >>> a["b"]["1a"]  # will work
    >>> # a.b.1a  # will not work
    ```

    Args:
        tag_dictionary (`dict`):
            A dictionary of tags returned from the /api/***-tags-by-type api
            endpoint.
        keys (`list`):
            A list of keys to unpack the `tag_dictionary` with, such as
            `["library","language"]`.
""" | |
def __init__(self, tag_dictionary: dict, keys: Optional[list] = None): | |
self._tag_dictionary = tag_dictionary | |
if keys is None: | |
keys = list(self._tag_dictionary.keys()) | |
for key in keys: | |
self._unpack_and_assign_dictionary(key) | |
def _unpack_and_assign_dictionary(self, key: str): | |
"Assign nested attributes to `self.key` containing information as an `AttributeDictionary`" | |
ref = AttributeDictionary() | |
setattr(self, key, ref) | |
for item in self._tag_dictionary.get(key, []): | |
label = item["label"].replace(" ", "").replace("-", "_").replace(".", "_") | |
ref[label] = item["id"] | |
self[key] = ref | |


class ModelTags(GeneralTags):
    """
    A namespace object holding all available model tags. If a tag starts with a
    number, it will only exist in the dictionary.

    Example:

    ```python
    >>> a.dataset["1_5BArabicCorpus"]  # will work
    >>> a["dataset"]["1_5BArabicCorpus"]  # will work
    >>> # a.dataset.1_5BArabicCorpus  # will not work
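    >>> # `a` here would typically come from `HfApi().get_model_tags()`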
    ```

    Args:
        model_tag_dictionary (`dict`):
            A dictionary of valid model tags, returned from the
            /api/models-tags-by-type api endpoint.
    """

    def __init__(self, model_tag_dictionary: dict):
        keys = ["library", "language", "license", "dataset", "pipeline_tag"]
        super().__init__(model_tag_dictionary, keys)


class DatasetTags(GeneralTags):
    """
    A namespace object holding all available dataset tags. If a tag starts with a
    number, it will only exist in the dictionary.

    Example:

    ```python
    >>> a.size_categories["100K<n<1M"]  # will work
    >>> a["size_categories"]["100K<n<1M"]  # will work
    >>> # a.size_categories.100K<n<1M  # will not work
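    >>> # `a` here would typically come from `HfApi().get_dataset_tags()`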
    ```

    Args:
        dataset_tag_dictionary (`dict`):
            A dictionary of valid dataset tags, returned from the
            /api/datasets-tags-by-type api endpoint.
    """

    def __init__(self, dataset_tag_dictionary: dict):
        keys = [
            "language",
            "multilinguality",
            "language_creators",
            "task_categories",
            "size_categories",
            "benchmark",
            "task_ids",
            "license",
        ]
        super().__init__(dataset_tag_dictionary, keys)