"""
File: model_translation.py

Description: 
   Loading models for text translation

Author: Didier Guillevic
Date: 2024-03-16
"""

import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from transformers import M2M100Tokenizer, M2M100ForConditionalGeneration
from transformers import BitsAndBytesConfig

quantization_config = BitsAndBytesConfig(
    load_in_8bit=True,
    llm_int8_threshold=200.0 # https://discuss.huggingface.co/t/correct-usage-of-bitsandbytesconfig/33809/5
)

class Singleton(type):
    _instances = {}
    def __call__(cls, *args, **kwargs):
        if cls not in cls._instances:
            cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs)
        return cls._instances[cls]
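
# Illustrative note (not part of the original module): with this metaclass,
# repeated instantiation returns the same object, so the heavy model weights
# are loaded at most once per process, e.g.:
#
#     assert ModelM2M100() is ModelM2M100()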

class ModelM2M100(metaclass=Singleton):
    """Loads an instance of the M2M100 model.
    """
    def __init__(self):
        self._model_name = "facebook/m2m100_1.2B"
        self._tokenizer = M2M100Tokenizer.from_pretrained(self._model_name)
        self._model = M2M100ForConditionalGeneration.from_pretrained(
            self._model_name,
            device_map="auto",
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True
        )
        self._model = torch.compile(self._model)
    
    @property
    def model_name(self):
        return self._model_name

    @property
    def tokenizer(self):
        return self._tokenizer

    @property
    def model(self):
        return self._model
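
# Usage sketch (illustrative only; the sentence and the language codes "en"
# and "fr" below are assumptions for the example, not part of this module):
#
#     m2m = ModelM2M100()
#     m2m.tokenizer.src_lang = "en"
#     encoded = m2m.tokenizer("Hello, world!", return_tensors="pt").to(m2m.model.device)
#     generated = m2m.model.generate(
#         **encoded,
#         forced_bos_token_id=m2m.tokenizer.get_lang_id("fr"),
#         max_new_tokens=128,
#     )
#     print(m2m.tokenizer.batch_decode(generated, skip_special_tokens=True))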

class ModelMADLAD(metaclass=Singleton):
    """Loads an instance of the Google MADLAD model (3B).
    """
    def __init__(self):
        self._model_name = "google/madlad400-3b-mt"
        self._tokenizer = AutoTokenizer.from_pretrained(
            self._model_name, use_fast=True
        )
        self._model = AutoModelForSeq2SeqLM.from_pretrained(
            self._model_name,
            device_map="auto",
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
            quantization_config=quantization_config
        )
        self._model = torch.compile(self._model)
    
    @property
    def model_name(self):
        return self._model_name
    
    @property
    def tokenizer(self):
        return self._tokenizer

    @property
    def model(self):
        return self._model
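
# Usage sketch (illustrative only): MADLAD-400 expects the target language as
# a "<2xx>" prefix on the input text. The "<2fr>" tag and the sentence below
# are assumptions for the example.
#
#     madlad = ModelMADLAD()
#     inputs = madlad.tokenizer("<2fr> I love pizza!", return_tensors="pt").to(madlad.model.device)
#     outputs = madlad.model.generate(**inputs, max_new_tokens=128)
#     print(madlad.tokenizer.decode(outputs[0], skip_special_tokens=True))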


# Bi-lingual individual models
# Note: "ja" is listed as a source language but has no entry in model_names
# below, so get_tokenizer_model_for_src_lang("ja") will raise.
src_langs = {"ar", "en", "fa", "fr", "he", "ja", "zh"}
model_names = {
    "ar": "Helsinki-NLP/opus-mt-ar-en",
    "en": "Helsinki-NLP/opus-mt-en-fr",
    "fa": "Helsinki-NLP/opus-mt-tc-big-fa-itc",
    "fr": "Helsinki-NLP/opus-mt-fr-en",
    "he": "Helsinki-NLP/opus-mt-tc-big-he-en",
    "zh": "Helsinki-NLP/opus-mt-zh-en",
}

# Registry for all loaded bilingual models
tokenizer_model_registry = {}

device = 'cpu'

def get_tokenizer_model_for_src_lang(src_lang: str) -> tuple[AutoTokenizer, AutoModelForSeq2SeqLM]:
    """
    Return the (tokenizer, model) for a given source language.
    """
    src_lang = src_lang.lower()

    # Already loaded?
    if src_lang in tokenizer_model_registry:
        return tokenizer_model_registry.get(src_lang)

    # Load tokenizer and model
    model_name = model_names.get(src_lang)
    if not model_name:
        raise ValueError(f"No model defined for source language: {src_lang}")
    
    # We will leave the models on the CPU (for now)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    if model.config.torch_dtype != torch.float16:
        model = model.half()
    model.to(device)
    tokenizer_model_registry[src_lang] = (tokenizer, model)

    return (tokenizer, model)
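
# Usage sketch (illustrative only; the French sentence is an assumption, not
# part of this module):
#
#     tokenizer, model = get_tokenizer_model_for_src_lang("fr")
#     inputs = tokenizer("Bonjour tout le monde.", return_tensors="pt").to(device)
#     outputs = model.generate(**inputs, max_new_tokens=128)
#     print(tokenizer.decode(outputs[0], skip_special_tokens=True))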

# Max number of words for given input text
# - Usually 512 tokens (max position encodings, as well as max length)
# - Let's set to some number of words somewhat lower than that threshold
# - e.g. 200 words
max_words_per_chunk = 200
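
# Illustrative sketch (not part of the original module): one possible way to
# split an input text into chunks of at most max_words_per_chunk words before
# translation. The helper name and whitespace-based word splitting are
# assumptions for the example.
def split_text_into_chunks(text: str, max_words: int = max_words_per_chunk) -> list[str]:
    """Split text into whitespace-delimited chunks of at most max_words words."""
    words = text.split()
    return [
        " ".join(words[i:i + max_words])
        for i in range(0, len(words), max_words)
    ]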