""" | |
File: model_translation.py | |
Description: | |
Loading models for text translations | |
Author: Didier Guillevic | |
Date: 2024-03-16 | |
""" | |
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from transformers import M2M100Tokenizer, M2M100ForConditionalGeneration
from transformers import BitsAndBytesConfig

quantization_config = BitsAndBytesConfig(
    load_in_8bit=True,
    llm_int8_threshold=200.0  # https://discuss.huggingface.co/t/correct-usage-of-bitsandbytesconfig/33809/5
)

class Singleton(type):
    """Metaclass ensuring that at most one instance of each class exists."""
    _instances = {}

    def __call__(cls, *args, **kwargs):
        if cls not in cls._instances:
            cls._instances[cls] = super().__call__(*args, **kwargs)
        return cls._instances[cls]
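
# Minimal illustration of the Singleton metaclass (hypothetical class shown
# purely for clarity; not used elsewhere in this file):
#
#   class Config(metaclass=Singleton):
#       pass
#
#   assert Config() is Config()  # repeated calls return the same instance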

class ModelM2M100(metaclass=Singleton):
    """Loads an instance of the M2M100 model."""

    def __init__(self):
        self._model_name = "facebook/m2m100_1.2B"
        self._tokenizer = M2M100Tokenizer.from_pretrained(self._model_name)
        self._model = M2M100ForConditionalGeneration.from_pretrained(
            self._model_name,
            device_map="auto",
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True
        )
        self._model = torch.compile(self._model)

    def model_name(self):
        return self._model_name

    def tokenizer(self):
        return self._tokenizer

    def model(self):
        return self._model
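
# Usage sketch for ModelM2M100 (illustrative; the text and language codes
# below are example values, not from this file). M2M100 needs the source
# language set on the tokenizer and the target language forced as the BOS
# token of the generated sequence:
#
#   m2m = ModelM2M100()
#   tokenizer, model = m2m.tokenizer(), m2m.model()
#   tokenizer.src_lang = "fr"
#   encoded = tokenizer("Bonjour le monde", return_tensors="pt").to(model.device)
#   generated = model.generate(
#       **encoded, forced_bos_token_id=tokenizer.get_lang_id("en"))
#   tokenizer.batch_decode(generated, skip_special_tokens=True)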

class ModelMADLAD(metaclass=Singleton):
    """Loads an instance of the Google MADLAD model (3B)."""

    def __init__(self):
        self._model_name = "google/madlad400-3b-mt"
        self._tokenizer = AutoTokenizer.from_pretrained(
            self._model_name, use_fast=True
        )
        self._model = AutoModelForSeq2SeqLM.from_pretrained(
            self._model_name,
            device_map="auto",
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
            quantization_config=quantization_config
        )
        self._model = torch.compile(self._model)

    def model_name(self):
        return self._model_name

    def tokenizer(self):
        return self._tokenizer

    def model(self):
        return self._model
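
# Usage sketch for ModelMADLAD (illustrative example values). Per the model
# card, MADLAD-400 expects the target language as a "<2xx>" token prefixed
# to the source text:
#
#   madlad = ModelMADLAD()
#   tokenizer, model = madlad.tokenizer(), madlad.model()
#   encoded = tokenizer("<2en> Bonjour le monde", return_tensors="pt").to(model.device)
#   generated = model.generate(**encoded, max_new_tokens=128)
#   tokenizer.batch_decode(generated, skip_special_tokens=True)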

# Bilingual individual models
src_langs = set(["ar", "en", "fa", "fr", "he", "ja", "zh"])
model_names = {
    "ar": "Helsinki-NLP/opus-mt-ar-en",
    "en": "Helsinki-NLP/opus-mt-en-fr",
    "fa": "Helsinki-NLP/opus-mt-tc-big-fa-itc",
    "fr": "Helsinki-NLP/opus-mt-fr-en",
    "he": "Helsinki-NLP/opus-mt-tc-big-he-en",
    "ja": "Helsinki-NLP/opus-mt-ja-en",
    "zh": "Helsinki-NLP/opus-mt-zh-en",
}

# Registry for all loaded bilingual models
tokenizer_model_registry = {}

device = 'cpu'

def get_tokenizer_model_for_src_lang(src_lang: str) -> tuple[AutoTokenizer, AutoModelForSeq2SeqLM]:
    """Return the (tokenizer, model) pair for a given source language."""
    src_lang = src_lang.lower()

    # Already loaded?
    if src_lang in tokenizer_model_registry:
        return tokenizer_model_registry[src_lang]

    # Load tokenizer and model
    model_name = model_names.get(src_lang)
    if not model_name:
        raise ValueError(f"No model defined for language: {src_lang}")

    # We will leave the models on the CPU (for now)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    if model.config.torch_dtype != torch.float16:
        model = model.half()
    model.to(device)

    tokenizer_model_registry[src_lang] = (tokenizer, model)
    return (tokenizer, model)
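
# Usage sketch (illustrative): bilingual models are loaded lazily on first
# request and then served from tokenizer_model_registry:
#
#   tokenizer, model = get_tokenizer_model_for_src_lang("fr")
#   encoded = tokenizer("Bonjour le monde", return_tensors="pt").to(device)
#   generated = model.generate(**encoded)
#   tokenizer.batch_decode(generated, skip_special_tokens=True)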

# Max number of words for a given input text
# - Models typically cap input at 512 tokens (max position embeddings and max length)
# - Hence, set a word count comfortably below that threshold, e.g. 200 words
max_words_per_chunk = 200
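
# Sketch of a word-based chunker built on max_words_per_chunk (hypothetical
# helper, not part of the original API; splitting on whitespace is a rough
# proxy for the tokenizer's token count):
def chunk_text(text: str, max_words: int = max_words_per_chunk) -> list[str]:
    """Split text into chunks of at most max_words whitespace-separated words."""
    words = text.split()
    return [
        ' '.join(words[i:i + max_words])
        for i in range(0, len(words), max_words)
    ]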