Spaces:
Sleeping
Sleeping
# coding=utf-8
# Copyright 2019-present, the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import io | |
import os | |
from os.path import expanduser | |
from typing import Dict, List, Optional, Tuple | |
from tqdm import tqdm | |
import requests | |
# Default base URL for API calls; HfApi falls back to it when constructed without an endpoint.
ENDPOINT = "https://huggingface.co"
class RepoObj:
    """
    Record describing one file that belongs to the current user in the HuggingFace git-based system.

    Only `filename`, `lastModified`, `commit` and `size` are stored; any other keyword
    arguments are accepted and ignored.
    """

    def __init__(self, filename: str, lastModified: str, commit: str, size: int, **kwargs):
        self.filename, self.lastModified = filename, lastModified
        self.commit, self.size = commit, size
class ModelSibling:
    """
    One public file inside a model, accessible from huggingface.co.

    `rfilename` is stored explicitly; every extra keyword argument is copied onto the
    instance as an attribute of the same name.
    """

    def __init__(self, rfilename: str, **kwargs):
        self.rfilename = rfilename  # filename relative to the model root
        for attr_name in kwargs:
            setattr(self, attr_name, kwargs[attr_name])
class ModelInfo:
    """
    Info about a public model accessible from huggingface.co

    Params:
        modelId: id of the model.
        tags: list of tags; a fresh empty list is used when omitted.
        pipeline_tag: optional pipeline tag.
        siblings: list of dicts, each describing one file of the model. Each dict is
            expanded into a `ModelSibling`; `None` is kept as `None`.
        **kwargs: any other field; each one is set as an attribute on the instance.
    """

    def __init__(
        self,
        modelId: Optional[str] = None,  # id of model
        tags: Optional[List[str]] = None,
        pipeline_tag: Optional[str] = None,
        siblings: Optional[List[Dict]] = None,  # list of files that constitute the model
        **kwargs
    ):
        self.modelId = modelId
        # A literal `[]` default would be one list object shared by every instance built
        # without `tags`; allocate a new list per instance instead.
        self.tags = tags if tags is not None else []
        self.pipeline_tag = pipeline_tag
        self.siblings = [ModelSibling(**x) for x in siblings] if siblings is not None else None
        for k, v in kwargs.items():
            setattr(self, k, v)
class HfApi:
    """
    Client for the HuggingFace HTTP API: sign-in, model listing and repo management.

    Every call raises `requests.exceptions.HTTPError` (via `raise_for_status`) on an
    error response, except where noted.
    """

    def __init__(self, endpoint=None):
        # Fall back to the module-level default when no endpoint is supplied.
        if endpoint is None:
            endpoint = ENDPOINT
        self.endpoint = endpoint

    def login(self, username: str, password: str) -> str:
        """
        Call HF API to sign in a user and get a token if credentials are valid.

        Outputs: token if credentials are valid

        Throws: requests.exceptions.HTTPError if credentials are invalid
        """
        credentials = {"username": username, "password": password}
        response = requests.post(f"{self.endpoint}/api/login", json=credentials)
        response.raise_for_status()
        return response.json()["token"]

    def whoami(self, token: str) -> Tuple[str, List[str]]:
        """
        Call HF API to know "whoami".

        Outputs: the `user` and `orgs` fields of the response, as a tuple.
        """
        auth = {"authorization": f"Bearer {token}"}
        response = requests.get(f"{self.endpoint}/api/whoami", headers=auth)
        response.raise_for_status()
        payload = response.json()
        return payload["user"], payload["orgs"]

    def logout(self, token: str) -> None:
        """
        Call HF API to log out.
        """
        auth = {"authorization": f"Bearer {token}"}
        response = requests.post(f"{self.endpoint}/api/logout", headers=auth)
        response.raise_for_status()

    def model_list(self) -> List[ModelInfo]:
        """
        Get the public list of all the models on huggingface.co
        """
        response = requests.get(f"{self.endpoint}/api/models")
        response.raise_for_status()
        models = []
        for entry in response.json():
            models.append(ModelInfo(**entry))
        return models

    def list_repos_objs(self, token: str, organization: Optional[str] = None) -> List[RepoObj]:
        """
        HuggingFace git-based system, used for models.

        Call HF API to list all stored files for user (or one of their organizations).
        """
        # The organization filter is only sent when one was requested.
        query = None
        if organization is not None:
            query = {"organization": organization}
        auth = {"authorization": f"Bearer {token}"}
        response = requests.get(f"{self.endpoint}/api/repos/ls", params=query, headers=auth)
        response.raise_for_status()
        objs = []
        for entry in response.json():
            objs.append(RepoObj(**entry))
        return objs

    def create_repo(
        self,
        token: str,
        name: str,
        organization: Optional[str] = None,
        private: Optional[bool] = None,
        exist_ok=False,
        lfsmultipartthresh: Optional[int] = None,
    ) -> str:
        """
        HuggingFace git-based system, used for models.

        Call HF API to create a whole repo.

        Params:
            private: Whether the model repo should be private (requires a paid huggingface.co account)

            exist_ok: Do not raise an error if repo already exists

            lfsmultipartthresh: Optional: internal param for testing purposes.

        Outputs: the `url` field of the response, or "" when the server answers 409 and
        `exist_ok` is set.
        """
        payload = {"name": name, "organization": organization, "private": private}
        if lfsmultipartthresh is not None:
            payload["lfsmultipartthresh"] = lfsmultipartthresh
        response = requests.post(
            f"{self.endpoint}/api/repos/create",
            headers={"authorization": f"Bearer {token}"},
            json=payload,
        )
        # A 409 response is tolerated only when the caller opted in with exist_ok.
        if exist_ok and response.status_code == 409:
            return ""
        response.raise_for_status()
        return response.json()["url"]

    def delete_repo(self, token: str, name: str, organization: Optional[str] = None):
        """
        HuggingFace git-based system, used for models.

        Call HF API to delete a whole repo.

        CAUTION(this is irreversible).
        """
        target = {"name": name, "organization": organization}
        response = requests.delete(
            f"{self.endpoint}/api/repos/delete",
            headers={"authorization": f"Bearer {token}"},
            json=target,
        )
        response.raise_for_status()
class TqdmProgressFileReader:
    """
    Wrap an io.BufferedReader `f` (such as the output of `open(…, "rb")`) and override `f.read()` so as to display a
    tqdm progress bar.

    The bar's total is the file size reported by `os.fstat`, and it advances by the
    number of bytes each `read` call actually returns.

    see github.com/huggingface/transformers/pull/2078#discussion_r354739608 for implementation details.
    """

    def __init__(self, f: io.BufferedReader):
        self.f = f
        self.total_size = os.fstat(f.fileno()).st_size
        self.pbar = tqdm(total=self.total_size, leave=False)
        # Keep the original bound `read` before replacing it, so `_read` can delegate to it.
        self.read = f.read
        f.read = self._read

    def _read(self, n=-1):
        """
        Read through the original `read` and advance the progress bar.

        The bar moves by the length of the returned data rather than by `n`: `n` may be
        -1 or None ("read everything") or larger than what is left in the file, none of
        which is a valid progress increment.
        """
        data = self.read(n)
        if data:
            self.pbar.update(len(data))
        return data

    def close(self):
        """Close the progress bar (the wrapped file is left open)."""
        self.pbar.close()
class HfFolder:
    """
    Store, read and delete the user's API token in the file at `path_token`.

    All methods are classmethods, so they are called on the class itself, e.g.
    `HfFolder.save_token(token)`.
    """

    path_token = expanduser("~/.huggingface/token")

    @classmethod
    def save_token(cls, token):
        """
        Save token, creating folder as needed.
        """
        os.makedirs(os.path.dirname(cls.path_token), exist_ok=True)
        with open(cls.path_token, "w+") as f:
            f.write(token)

    @classmethod
    def get_token(cls):
        """
        Get token or None if not existent.
        """
        try:
            with open(cls.path_token, "r") as f:
                return f.read()
        except FileNotFoundError:
            # A missing token file is a normal state, not an error.
            return None

    @classmethod
    def delete_token(cls):
        """
        Delete token. Do not fail if token does not exist.
        """
        try:
            os.remove(cls.path_token)
        except FileNotFoundError:
            pass