from huggingface_hub import HfApi
from huggingface_hub.utils import GatedRepoError, RepositoryNotFoundError
from requests import HTTPError

from app.config import HF_TOKEN


def is_model_on_hub(
    model_id: str, revision: str, token: str = HF_TOKEN, trust_remote_code: bool = False, test_tokenizer: bool = True
) -> tuple[bool, str, dict | None]:
    """Check whether a model is available on the Hugging Face Hub.

    Returns:
        tuple[bool, str, dict | None]: whether the model is on the hub, an error
        message (empty on success), and the model config if one is available.
    """
    if not token:
        return (
            False,
            "No Hugging Face token provided. Please create a read token on the Hugging Face website and add it as a secret with the name `HF_TOKEN`.",
            None,
        )

    api = HfApi(token=token)
    try:
        model_info = api.model_info(model_id, revision=revision)
        # `config` is only present on some model infos; fall back to None.
        model_config = getattr(model_info, "config", None)
    # GatedRepoError subclasses RepositoryNotFoundError, so it must be caught first.
    except GatedRepoError:
        return False, f"Model {model_id} is gated, you need to accept the license agreement first.", None
    except RepositoryNotFoundError:
        return False, f"Model {model_id} not found on hub", None
    except HTTPError as e:
        return False, f"Could not fetch info for model {model_id}. Error: {e}", None

    if trust_remote_code and test_tokenizer:
        from transformers import AutoTokenizer

        # Loading the tokenizer verifies that the repo's custom code is importable.
        try:
            AutoTokenizer.from_pretrained(model_id, revision=revision, trust_remote_code=True, token=token)
        except Exception as e:
            return False, f"Could not load tokenizer for {model_id}. Error: {e}", None

    return True, "", model_config
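

# A minimal usage sketch (an assumption, not part of the original module):
# the model id below is a hypothetical example, and the token falls back to
# the HF_TOKEN configured in app.config.
if __name__ == "__main__":
    on_hub, error, config = is_model_on_hub("gpt2", revision="main")
    if on_hub:
        print(f"Model found. Config: {config}")
    else:
        print(f"Check failed: {error}")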