"""
File: model_translation.py
Description:
    Loading models for text translation
Author: Didier Guillevic
Date: 2024-03-16
"""
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from transformers import M2M100Tokenizer, M2M100ForConditionalGeneration
from transformers import BitsAndBytesConfig
quantization_config = BitsAndBytesConfig(
load_in_8bit=True,
llm_int8_threshold=200.0 # https://discuss.huggingface.co/t/correct-usage-of-bitsandbytesconfig/33809/5
)
class Singleton(type):
_instances = {}
def __call__(cls, *args, **kwargs):
if cls not in cls._instances:
            cls._instances[cls] = super().__call__(*args, **kwargs)
return cls._instances[cls]
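
# A minimal sketch of what the Singleton metaclass guarantees: instantiating a
# class twice yields the same cached object, so each model below is loaded at
# most once per process. (Illustrative class name; not used in this module.)
#
#     class Expensive(metaclass=Singleton):
#         def __init__(self):
#             print("loading...")  # runs only on the first instantiation
#
#     a = Expensive()  # prints "loading..."
#     b = Expensive()  # reuses the cached instance; nothing printed
#     assert a is b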
class ModelM2M100(metaclass=Singleton):
"""Loads an instance of the M2M100 model.
"""
def __init__(self):
self._model_name = "facebook/m2m100_1.2B"
self._tokenizer = M2M100Tokenizer.from_pretrained(self._model_name)
self._model = M2M100ForConditionalGeneration.from_pretrained(
self._model_name,
device_map="auto",
torch_dtype=torch.float16,
low_cpu_mem_usage=True
)
self._model = torch.compile(self._model)
@property
def model_name(self):
return self._model_name
@property
def tokenizer(self):
return self._tokenizer
@property
def model(self):
return self._model
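
# Usage sketch (illustrative, not part of this module's API): M2M100 is a
# many-to-many model, so the source language is set on the tokenizer and the
# target language is forced via the BOS token at generation time.
#
#     m2m = ModelM2M100()
#     m2m.tokenizer.src_lang = "fr"
#     encoded = m2m.tokenizer("La vie est belle.", return_tensors="pt").to(m2m.model.device)
#     generated = m2m.model.generate(
#         **encoded, forced_bos_token_id=m2m.tokenizer.get_lang_id("en")
#     )
#     print(m2m.tokenizer.batch_decode(generated, skip_special_tokens=True))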
class ModelMADLAD(metaclass=Singleton):
"""Loads an instance of the Google MADLAD model (3B).
"""
def __init__(self):
self._model_name = "google/madlad400-3b-mt"
self._tokenizer = AutoTokenizer.from_pretrained(
            self._model_name, use_fast=True
)
self._model = AutoModelForSeq2SeqLM.from_pretrained(
self._model_name,
device_map="auto",
torch_dtype=torch.float16,
low_cpu_mem_usage=True,
quantization_config=quantization_config
)
self._model = torch.compile(self._model)
@property
def model_name(self):
return self._model_name
@property
def tokenizer(self):
return self._tokenizer
@property
def model(self):
return self._model
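
# Usage sketch (illustrative): MADLAD-400 selects the target language with a
# "<2xx>" prefix token on the input text rather than a tokenizer attribute.
#
#     madlad = ModelMADLAD()
#     inputs = madlad.tokenizer("<2fr> Life is beautiful.", return_tensors="pt").to(madlad.model.device)
#     outputs = madlad.model.generate(**inputs, max_new_tokens=128)
#     print(madlad.tokenizer.batch_decode(outputs, skip_special_tokens=True))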
# Bilingual individual models
src_langs = {"ar", "en", "fa", "fr", "he", "ja", "zh"}
model_names = {
    "ar": "Helsinki-NLP/opus-mt-ar-en",
    "en": "Helsinki-NLP/opus-mt-en-fr",
    "fa": "Helsinki-NLP/opus-mt-tc-big-fa-itc",
    "fr": "Helsinki-NLP/opus-mt-fr-en",
    "he": "Helsinki-NLP/opus-mt-tc-big-he-en",
    "ja": "Helsinki-NLP/opus-mt-ja-en",  # added: "ja" is in src_langs but had no model entry
    "zh": "Helsinki-NLP/opus-mt-zh-en",
}
# Registry for all loaded bilingual models
tokenizer_model_registry = {}
device = 'cpu'
def get_tokenizer_model_for_src_lang(src_lang: str) -> tuple[AutoTokenizer, AutoModelForSeq2SeqLM]:
"""
Return the (tokenizer, model) for a given source language.
"""
src_lang = src_lang.lower()
# Already loaded?
if src_lang in tokenizer_model_registry:
return tokenizer_model_registry.get(src_lang)
# Load tokenizer and model
model_name = model_names.get(src_lang)
if not model_name:
raise Exception(f"No model defined for language: {src_lang}")
# We will leave the models on the CPU (for now)
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
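    # Cast to fp16 to halve memory use; note (assumption): not every op
    # supports fp16 on CPU, so this relies on the chosen models tolerating it.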
if model.config.torch_dtype != torch.float16:
model = model.half()
model.to(device)
tokenizer_model_registry[src_lang] = (tokenizer, model)
return (tokenizer, model)
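
# Example (sketch, not executed here): translating with one of the bilingual
# Opus-MT models via the registry above; the pair is loaded once and reused.
#
#     tokenizer, model = get_tokenizer_model_for_src_lang("fr")
#     inputs = tokenizer("Bonjour tout le monde.", return_tensors="pt").to(device)
#     outputs = model.generate(**inputs, max_new_tokens=128)
#     print(tokenizer.batch_decode(outputs, skip_special_tokens=True))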
# Max number of words for a given input text
# - The models typically cap inputs at 512 tokens (max position embeddings,
#   which is also the max length)
# - So we chunk on a word count somewhat below that threshold, e.g. 200 words
max_words_per_chunk = 200
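
# A possible word-based chunking helper built on this limit (a hypothetical
# sketch; this module does not define one itself):
#
#     def chunk_text(text: str, max_words: int = max_words_per_chunk) -> list[str]:
#         """Split text into chunks of at most max_words whitespace-separated words."""
#         words = text.split()
#         return [
#             " ".join(words[i:i + max_words])
#             for i in range(0, len(words), max_words)
#         ]
#
#     # chunk_text("a b c d e", max_words=2) -> ["a b", "c d", "e"]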