"""
File: model_translation.py
Description:
Loading models for text translations
Author: Didier Guillevic
Date: 2024-03-16
"""
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
src_langs = {"ar", "en", "fa", "fr", "he", "ja", "zh"}
model_names = {
    "ar": "Helsinki-NLP/opus-mt-ar-en",
    "en": "Helsinki-NLP/opus-mt-en-fr",
    "fa": "Helsinki-NLP/opus-mt-tc-big-fa-itc",
    "fr": "Helsinki-NLP/opus-mt-fr-en",
    "he": "Helsinki-NLP/opus-mt-tc-big-he-en",
    "ja": "Helsinki-NLP/opus-mt-jap-en",
    "zh": "Helsinki-NLP/opus-mt-zh-en",
}
# Registry for all loaded bilingual models
tokenizer_model_registry = {}
device = "cpu"
def get_tokenizer_model_for_src_lang(
    src_lang: str,
) -> tuple[AutoTokenizer, AutoModelForSeq2SeqLM]:
    """Return the (tokenizer, model) pair for a given source language."""
    src_lang = src_lang.lower()

    # Already loaded?
    if src_lang in tokenizer_model_registry:
        return tokenizer_model_registry[src_lang]

    # Load the tokenizer and model for that language
    model_name = model_names.get(src_lang)
    if not model_name:
        raise ValueError(f"No model defined for language: {src_lang}")

    # We will leave the models on the CPU (for now), cast to half precision
    # to reduce the memory footprint.
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    if model.config.torch_dtype != torch.float16:
        model = model.half()
    model.to(device)

    tokenizer_model_registry[src_lang] = (tokenizer, model)
    return tokenizer, model
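
# Example usage sketch (not part of the original module): translate a short
# text with one of the bilingual models loaded above. The function name
# `translate_text` is illustrative, not an established API.
def translate_text(text: str, src_lang: str) -> str:
    """Translate `text` from `src_lang` with the matching bilingual model."""
    tokenizer, model = get_tokenizer_model_for_src_lang(src_lang)
    inputs = tokenizer(text, return_tensors="pt", truncation=True).to(device)
    with torch.inference_mode():
        output_ids = model.generate(**inputs, max_new_tokens=512)
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)

# e.g.: translate_text("Bonjour tout le monde", "fr")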
# Max number of words for a given input text
# - The models usually accept at most 512 tokens
#   (max position embeddings, as well as max generation length).
# - We therefore chunk the input by words, staying somewhat below that
#   threshold: e.g. 200 words per chunk.
max_words_per_chunk = 200
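
# Minimal chunking sketch implementing the comment above (illustrative, not
# part of the original module): split text into chunks of at most
# `max_words_per_chunk` words so each chunk stays below the 512-token limit.
def chunk_text(text: str, max_words: int = max_words_per_chunk) -> list[str]:
    """Split `text` into chunks of at most `max_words` words."""
    words = text.split()
    return [
        " ".join(words[i : i + max_words])
        for i in range(0, len(words), max_words)
    ]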
#
# Multilingual language pairs
#
from transformers import M2M100Tokenizer, M2M100ForConditionalGeneration
model_name_m2m100 = "facebook/m2m100_418M"
tokenizer_m2m100 = M2M100Tokenizer.from_pretrained(model_name_m2m100)
# Note: load_in_8bit=True requires the bitsandbytes package and a CUDA GPU;
# torch_dtype=torch.float16 applies to the modules kept out of 8-bit.
model_m2m100 = M2M100ForConditionalGeneration.from_pretrained(
    model_name_m2m100,
    device_map="auto",
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
    load_in_8bit=True,
)
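
# Example usage sketch for M2M100 (illustrative, not part of the original
# module): the tokenizer must know the source language, and the target
# language is forced as the first generated token.
def translate_m2m100(text: str, src_lang: str, tgt_lang: str) -> str:
    """Translate `text` from `src_lang` to `tgt_lang` with M2M100."""
    tokenizer_m2m100.src_lang = src_lang
    inputs = tokenizer_m2m100(text, return_tensors="pt").to(model_m2m100.device)
    with torch.inference_mode():
        output_ids = model_m2m100.generate(
            **inputs,
            forced_bos_token_id=tokenizer_m2m100.get_lang_id(tgt_lang),
        )
    return tokenizer_m2m100.batch_decode(output_ids, skip_special_tokens=True)[0]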
#
# Multilingual translation model
#
model_MADLAD_name = "google/madlad400-3b-mt"
# model_MADLAD_name = "google/madlad400-7b-mt-bt"
tokenizer_multilingual = AutoTokenizer.from_pretrained(model_MADLAD_name, use_fast=True)
# As above, 8-bit loading requires bitsandbytes and a CUDA device.
model_multilingual = AutoModelForSeq2SeqLM.from_pretrained(
    model_MADLAD_name,
    device_map="auto",
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
    load_in_8bit=True,
)
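
# Example usage sketch for MADLAD-400 (illustrative, not part of the original
# module): MADLAD expects the target language as a "<2xx>" token prefixed to
# the source text, e.g. "<2en> ..." to translate into English.
def translate_multilingual(text: str, tgt_lang: str) -> str:
    """Translate `text` into `tgt_lang` with the MADLAD-400 model."""
    inputs = tokenizer_multilingual(
        f"<2{tgt_lang}> {text}", return_tensors="pt"
    ).to(model_multilingual.device)
    with torch.inference_mode():
        output_ids = model_multilingual.generate(**inputs, max_new_tokens=512)
    return tokenizer_multilingual.decode(output_ids[0], skip_special_tokens=True)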