Text_translation / model_translation.py
Didier's picture
Removing m2m100 and quantizing MADLAD in 8-bit (as GPU resources limited)
ea7bc2f
raw
history blame
3.83 kB
"""
File: model_translation.py
Description:
Load the models (and their tokenizers) used for text translation.
Author: Didier Guillevic
Date: 2024-03-16
"""
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from transformers import M2M100Tokenizer, M2M100ForConditionalGeneration
from transformers import BitsAndBytesConfig
# 8-bit quantization settings, passed as `quantization_config` to
# `from_pretrained` when loading the MADLAD model.
# The int8 threshold value follows this discussion:
# https://discuss.huggingface.co/t/correct-usage-of-bitsandbytesconfig/33809/5
quantization_config = BitsAndBytesConfig(
    load_in_8bit=True,
    llm_int8_threshold=200.0,
)
class Singleton(type):
    """Metaclass that gives each class using it exactly one shared instance.

    The first call ``SomeClass(...)`` builds the instance normally and stores
    it; every later call returns that stored object, and any arguments passed
    to those later calls are ignored.
    """

    # Maps each class to the single instance created for it.
    _instances = {}

    def __call__(cls, *args, **kwargs):
        """Return the stored instance of ``cls``, creating it on first use."""
        created = cls._instances
        if cls in created:
            return created[cls]

        instance = super().__call__(*args, **kwargs)
        created[cls] = instance
        return instance
class ModelM2M100(metaclass=Singleton):
    """Shared holder for the M2M100 translation model and its tokenizer.

    The class uses the ``Singleton`` metaclass, so the checkpoint is loaded
    only the first time ``ModelM2M100()`` is called; later calls return the
    same instance.
    """

    def __init__(self):
        checkpoint = "facebook/m2m100_1.2B"
        self._model_name = checkpoint
        self._tokenizer = M2M100Tokenizer.from_pretrained(checkpoint)

        # float16 weights; device placement is delegated to `device_map="auto"`.
        load_options = {
            "device_map": "auto",
            "torch_dtype": torch.float16,
            "low_cpu_mem_usage": True,
        }
        loaded = M2M100ForConditionalGeneration.from_pretrained(
            checkpoint, **load_options
        )
        # The exposed model is the `torch.compile` wrapper of the loaded model.
        self._model = torch.compile(loaded)

    @property
    def model_name(self):
        """Name of the checkpoint that was loaded."""
        return self._model_name

    @property
    def tokenizer(self):
        """Tokenizer loaded for the checkpoint."""
        return self._tokenizer

    @property
    def model(self):
        """The loaded model, wrapped by ``torch.compile``."""
        return self._model
class ModelMADLAD(metaclass=Singleton):
    """Shared holder for the Google MADLAD (3B) translation model and tokenizer.

    The class uses the ``Singleton`` metaclass, so the checkpoint is loaded
    only the first time ``ModelMADLAD()`` is called; later calls return the
    same instance.
    """

    def __init__(self):
        checkpoint = "google/madlad400-3b-mt"
        self._model_name = checkpoint
        self._tokenizer = AutoTokenizer.from_pretrained(checkpoint, use_fast=True)

        # float16 for the non-quantized parts, plus the module-level
        # `quantization_config`; device placement is delegated to
        # `device_map="auto"`.
        load_options = {
            "device_map": "auto",
            "torch_dtype": torch.float16,
            "low_cpu_mem_usage": True,
            "quantization_config": quantization_config,
        }
        loaded = AutoModelForSeq2SeqLM.from_pretrained(checkpoint, **load_options)
        # The exposed model is the `torch.compile` wrapper of the loaded model.
        self._model = torch.compile(loaded)

    @property
    def model_name(self):
        """Name of the checkpoint that was loaded."""
        return self._model_name

    @property
    def tokenizer(self):
        """Tokenizer loaded for the checkpoint."""
        return self._tokenizer

    @property
    def model(self):
        """The loaded model, wrapped by ``torch.compile``."""
        return self._model
# Bilingual (single language pair) models, selected by source language code.
# NOTE(review): "ja" is listed in `src_langs` but has no entry in
# `model_names` below.
src_langs = {"ar", "en", "fa", "fr", "he", "ja", "zh"}

# Source language code -> name of the model checkpoint to load for it.
model_names = dict(
    ar="Helsinki-NLP/opus-mt-ar-en",
    en="Helsinki-NLP/opus-mt-en-fr",
    fa="Helsinki-NLP/opus-mt-tc-big-fa-itc",
    fr="Helsinki-NLP/opus-mt-fr-en",
    he="Helsinki-NLP/opus-mt-tc-big-he-en",
    zh="Helsinki-NLP/opus-mt-zh-en",
)

# Bilingual models loaded so far: source language code -> (tokenizer, model).
tokenizer_model_registry = dict()

# Device the bilingual models are moved to after loading.
device = "cpu"
def get_tokenizer_model_for_src_lang(
    src_lang: str,
) -> "tuple[AutoTokenizer, AutoModelForSeq2SeqLM]":
    """Return the (tokenizer, model) pair for a given source language.

    The pair is loaded on first request and kept in
    ``tokenizer_model_registry``; later requests for the same language return
    the stored pair without reloading.

    Args:
        src_lang: Source language code (e.g. ``"fr"``); matched
            case-insensitively against the keys of ``model_names``.

    Returns:
        A ``(tokenizer, model)`` tuple for the language.

    Raises:
        ValueError: If ``model_names`` has no model for ``src_lang``.
    """
    src_lang = src_lang.lower()

    # Already loaded? (Stored values are always tuples, never None.)
    cached = tokenizer_model_registry.get(src_lang)
    if cached is not None:
        return cached

    model_name = model_names.get(src_lang)
    if not model_name:
        raise ValueError(f"No model defined for language: {src_lang}")

    # We will leave the models on the CPU (for now)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

    # Convert to half precision unless the checkpoint is already float16.
    # NOTE(review): the target device is 'cpu'; confirm that the installed
    # PyTorch supports the half-precision CPU ops these models need.
    if model.config.torch_dtype != torch.float16:
        model = model.half()
    model.to(device)

    pair = (tokenizer, model)
    tokenizer_model_registry[src_lang] = pair
    return pair
# Maximum number of words per chunk of input text.
# - The usual model limit is 512 tokens (max position encodings, which is
#   also the max length).
# - The cap is expressed in words and set somewhat below that token
#   threshold, e.g. 200 words.
max_words_per_chunk = 200