# Text_translation / model_translation.py
# Initial commit: bilingual models, multilingual mode, Google Translate (fe02c49)
"""
File: model_translation.py
Description:
    Loading models for text translation: bilingual Helsinki-NLP pairs
    (ar/en/fa/fr/he/ja/zh) and a multilingual MADLAD model.
Author: Didier Guillevic
Date: 2024-03-16
"""
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
# Supported source languages (ISO 639-1 codes); each has an entry in
# model_names below.
src_langs: set[str] = {"ar", "en", "fa", "fr", "he", "ja", "zh"}

# Source language -> Hugging Face model id for the bilingual model.
# Targets are English, except "en" which translates into French.
model_names: dict[str, str] = {
    "ar": "Helsinki-NLP/opus-mt-ar-en",
    "en": "Helsinki-NLP/opus-mt-en-fr",
    "fa": "Helsinki-NLP/opus-mt-tc-big-fa-itc",
    "fr": "Helsinki-NLP/opus-mt-fr-en",
    "he": "Helsinki-NLP/opus-mt-tc-big-he-en",
    "ja": "Helsinki-NLP/opus-mt-jap-en",
    "zh": "Helsinki-NLP/opus-mt-zh-en",
}

# Registry (lazy cache) for all loaded bilingual models:
# src_lang -> (tokenizer, model).
tokenizer_model_registry: dict = {}

# Prefer the GPU when one is available.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
def get_tokenizer_model_for_src_lang(
        src_lang: str) -> tuple[AutoTokenizer, AutoModelForSeq2SeqLM]:
    """Return the (tokenizer, model) pair for a given source language.

    Models are loaded lazily on first request and cached in
    tokenizer_model_registry for subsequent calls.

    Args:
        src_lang: source language code (case-insensitive), e.g. "fr".

    Returns:
        The (tokenizer, model) pair for translating from src_lang.

    Raises:
        ValueError: if no model is defined for src_lang.
    """
    src_lang = src_lang.lower()

    # Already loaded? Serve from the cache.
    if src_lang in tokenizer_model_registry:
        return tokenizer_model_registry[src_lang]

    # Load tokenizer and model
    model_name = model_names.get(src_lang)
    if not model_name:
        # ValueError (not bare Exception) so callers can catch precisely;
        # existing `except Exception` handlers still work.
        raise ValueError(f"No model defined for language: {src_lang}")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    # Convert weights to fp16 to halve memory.
    # NOTE(review): when device == 'cpu', fp16 inference is slow or
    # unsupported for some ops — confirm half precision is intended there.
    if model.config.torch_dtype != torch.float16:
        model = model.half()
    model = model.to(device)

    tokenizer_model_registry[src_lang] = (tokenizer, model)
    return tokenizer, model
# Max number of words for given input text
# - Usually 512 tokens (max position encodings, as well as max length)
# - Let's set to some number of words somewhat lower than that threshold
# - e.g. 200 words
max_words_per_chunk: int = 200
#
# Multilingual translation model
#
# NOTE: loaded eagerly at import time — downloading/loading a ~3B-parameter
# model can take substantial time and memory on first run.
model_MADLAD_name = "google/madlad400-3b-mt"
#model_MADLAD_name = "google/madlad400-7b-mt-bt"
tokenizer_multilingual = AutoTokenizer.from_pretrained(model_MADLAD_name, use_fast=True)
# fp16 weights; device_map="auto" presumably lets transformers/accelerate
# place layers on available device(s) — verify against transformers docs.
model_multilingual = AutoModelForSeq2SeqLM.from_pretrained(
    model_MADLAD_name, device_map="auto", torch_dtype=torch.float16)