File size: 4,243 Bytes
d7a594b
bce5ce9
8981122
d7a594b
9ced7bd
5fbdd3c
 
 
 
de9c8c4
5fbdd3c
 
 
 
 
 
 
 
 
bce5ce9
 
5fbdd3c
9ced7bd
 
5fbdd3c
de9c8c4
5fbdd3c
 
 
bce5ce9
 
 
 
 
5fbdd3c
d7a594b
5fbdd3c
9ced7bd
 
5fbdd3c
 
 
 
9ced7bd
 
5fbdd3c
 
 
 
9ced7bd
 
5fbdd3c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7781f10
9ced7bd
d7a594b
9ced7bd
 
d7a594b
 
 
 
 
 
9ced7bd
 
4678c9b
d7a594b
9ced7bd
 
d7a594b
 
 
 
 
 
9ced7bd
 
5fbdd3c
 
bce5ce9
 
 
de9c8c4
bce5ce9
d7a594b
bce5ce9
 
d7a594b
 
de9c8c4
 
bce5ce9
 
721bf64
bce5ce9
 
 
 
 
721bf64
 
bce5ce9
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
from huggingface_hub import hf_hub_download
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from shared import CustomTokens
from errors import ClassifierLoadError, ModelLoadError
from functools import lru_cache
import pickle
import os
from dataclasses import dataclass, field
from typing import Optional
import torch


@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.

    Each field carries a ``help`` entry in its ``metadata`` so it can be exposed
    as a command-line option (NOTE(review): presumably parsed with
    ``transformers.HfArgumentParser`` — the parser is not visible in this file).
    """

    # Hub identifier or local path of the model. ``None`` means "not chosen";
    # get_model_tokenizer rejects ``None`` with ModelLoadError.
    model_name_or_path: Optional[str] = field(
        default=None,
        # default='google/t5-v1_1-small',  # t5-small
        metadata={
            'help': 'Path to pretrained model or model identifier from huggingface.co/models'
        }
    )

    # config_name: Optional[str] = field( # TODO remove?
    #     default=None, metadata={'help': 'Pretrained config name or path if not the same as model_name'}
    # )
    # tokenizer_name: Optional[str] = field(
    #     default=None, metadata={
    #         'help': 'Pretrained tokenizer name or path if not the same as model_name'
    #     }
    # )

    # Local directory used as the download cache for pretrained models.
    cache_dir: Optional[str] = field(
        default='models',
        metadata={
            'help': 'Where to store the pretrained models downloaded from huggingface.co'
        },
    )
    # Not consumed by get_model_tokenizer in this file.
    use_fast_tokenizer: bool = field(  # TODO remove?
        default=True,
        metadata={
            'help': 'Whether to use one of the fast tokenizer (backed by the tokenizers library) or not.'
        },
    )
    # Not consumed by get_model_tokenizer in this file.
    model_revision: str = field(  # TODO remove?
        default='main',
        metadata={
            'help': 'The specific model version to use (can be a branch name, tag name or commit id).'
        },
    )
    use_auth_token: bool = field(
        default=False,
        metadata={
            'help': 'Will use the token generated when running `transformers-cli login` (necessary to use this script '
            'with private models).'
        },
    )
    # Tri-state: ``None`` leaves the decision to the consumer of this flag.
    resize_position_embeddings: Optional[bool] = field(
        default=None,
        metadata={
            'help': "Whether to automatically resize the position embeddings if `max_source_length` exceeds the model's position embeddings."
        },
    )


def _load_pickled_artifact(classifier_args, filename):
    """Return the unpickled object stored as ``filename`` for the classifier repo.

    The file is looked up as ``classifier_args.classifier_dir/filename``. If it
    is not there, it is downloaded from the Hugging Face Hub repository
    ``classifier_args.classifier_model`` and loaded from the path that
    ``hf_hub_download`` reports.
    """
    local_path = os.path.join(classifier_args.classifier_dir, filename)
    if not os.path.exists(local_path):
        # Trust the path returned by the download call instead of assuming the
        # file landed at ``local_path``: when ``force_filename`` is not honoured
        # (NOTE(review): it is deprecated in newer huggingface_hub releases),
        # files under ``cache_dir`` are stored in a nested cache layout.
        local_path = hf_hub_download(repo_id=classifier_args.classifier_model,
                                     filename=filename,
                                     cache_dir=classifier_args.classifier_dir,
                                     force_filename=filename,
                                     )
    # SECURITY: pickle.load can execute arbitrary code embedded in the file.
    # Only point classifier_model / classifier_dir at trusted sources.
    with open(local_path, 'rb') as fp:
        return pickle.load(fp)


@lru_cache(maxsize=None)
def get_classifier_vectorizer(classifier_args):
    """Load (downloading if needed) the pickled classifier and vectorizer.

    Args:
        classifier_args: object exposing ``classifier_dir``, ``classifier_file``,
            ``vectorizer_file`` and ``classifier_model`` attributes. It must be
            hashable because results are memoised with ``lru_cache``.

    Returns:
        ``(classifier, vectorizer)`` as unpickled from their respective files.
    """
    classifier = _load_pickled_artifact(
        classifier_args, classifier_args.classifier_file)
    vectorizer = _load_pickled_artifact(
        classifier_args, classifier_args.vectorizer_file)
    return classifier, vectorizer


@lru_cache(maxsize=None)
def get_model_tokenizer(model_name_or_path, cache_dir=None, no_cuda=False):
    """Load a seq2seq model and its tokenizer, with the project's custom tokens.

    Results are memoised per ``(model_name_or_path, cache_dir, no_cuda)``.

    Args:
        model_name_or_path: Hub identifier or local path passed to
            ``from_pretrained`` for both the model and the tokenizer.
        cache_dir: download cache directory forwarded to ``from_pretrained``.
        no_cuda: when true, the model is left where ``from_pretrained`` put it.
            Otherwise it is moved to ``'cuda'`` if available, else ``'cpu'``.

    Returns:
        ``(model, tokenizer)``. The tokenizer has had
        ``CustomTokens.add_custom_tokens`` applied, and the model's token
        embeddings are resized to ``len(tokenizer)``.

    Raises:
        ModelLoadError: if ``model_name_or_path`` is ``None``.
    """
    if model_name_or_path is None:
        raise ModelLoadError('Invalid model_name_or_path.')

    # Load pretrained model and tokenizer
    seq2seq_model = AutoModelForSeq2SeqLM.from_pretrained(
        model_name_or_path, cache_dir=cache_dir)
    if no_cuda is False or not no_cuda:
        target_device = 'cuda' if torch.cuda.is_available() else 'cpu'
        seq2seq_model.to(target_device)

    seq2seq_tokenizer = AutoTokenizer.from_pretrained(
        model_name_or_path, cache_dir=cache_dir)

    # The vocabulary grows when custom tokens are added, so the embedding
    # matrix must be resized to match it.
    CustomTokens.add_custom_tokens(seq2seq_tokenizer)
    seq2seq_model.resize_token_embeddings(len(seq2seq_tokenizer))

    # TODO find a way to adjust based on model's input size
    # print('tokenizer.model_max_length', tokenizer.model_max_length)

    return seq2seq_model, seq2seq_tokenizer