"""
File: model_translation.py

Description: Loading models for text translations (EN->FR, FR->EN)

Author: Didier Guillevic
Date: 2024-03-16
"""

import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Source languages with a dedicated bilingual model below.
src_langs = {"ar", "en", "fa", "fr", "he", "ja", "zh"}

# src_lang -> Hugging Face model name (translating into English, except en->fr).
model_names = {
    "ar": "Helsinki-NLP/opus-mt-ar-en",
    "en": "Helsinki-NLP/opus-mt-en-fr",
    "fa": "Helsinki-NLP/opus-mt-tc-big-fa-itc",
    "fr": "Helsinki-NLP/opus-mt-fr-en",
    "he": "Helsinki-NLP/opus-mt-tc-big-he-en",
    "ja": "Helsinki-NLP/opus-mt-jap-en",
    "zh": "Helsinki-NLP/opus-mt-zh-en",
}

# Registry for all loaded bilingual models: src_lang -> (tokenizer, model)
tokenizer_model_registry: dict[str, tuple[AutoTokenizer, AutoModelForSeq2SeqLM]] = {}

device = 'cuda' if torch.cuda.is_available() else 'cpu'


def get_tokenizer_model_for_src_lang(
    src_lang: str,
) -> tuple[AutoTokenizer, AutoModelForSeq2SeqLM]:
    """Return the (tokenizer, model) pair for a given source language.

    Models are lazily loaded on first use and cached in
    ``tokenizer_model_registry`` for subsequent calls.

    Args:
        src_lang: language code (case-insensitive), e.g. "fr".

    Returns:
        The (tokenizer, model) pair; the model is cast to float16 and
        moved to ``device``.

    Raises:
        ValueError: if no model is defined for the given language.
    """
    src_lang = src_lang.lower()

    # Already loaded?
    if src_lang in tokenizer_model_registry:
        return tokenizer_model_registry[src_lang]

    # Load tokenizer and model
    model_name = model_names.get(src_lang)
    if not model_name:
        raise ValueError(f"No model defined for language: {src_lang}")

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    # Cast to half precision to reduce memory, then move to the selected
    # device (cuda when available, cpu otherwise).
    # NOTE(review): .half() on a CPU-only machine can be slow or unsupported
    # for some ops — confirm float16 inference works on the target hardware.
    if model.config.torch_dtype != torch.float16:
        model = model.half()
    model = model.to(device)

    tokenizer_model_registry[src_lang] = (tokenizer, model)
    return (tokenizer, model)


# Max number of words for given input text
# - Usually 512 tokens (max position encodings, as well as max length)
# - Let's set to some number of words somewhat lower than that threshold
# - e.g. 200 words
max_words_per_chunk = 200

#
# Multilingual translation model
#
model_MADLAD_name = "google/madlad400-3b-mt"
#model_MADLAD_name = "google/madlad400-7b-mt-bt"
tokenizer_multilingual = AutoTokenizer.from_pretrained(model_MADLAD_name, use_fast=True)
model_multilingual = AutoModelForSeq2SeqLM.from_pretrained(
    model_MADLAD_name,
    device_map="auto",
    torch_dtype=torch.float16)