"""
File: model_translation.py
Description:
    Loading models for text translations (EN->FR, FR->EN)
Author: Didier Guillevic
Date: 2024-03-16
"""
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Source languages we can translate from.
# Targets: EN for all source languages, FR when the source is EN.
src_langs = {"ar", "en", "fa", "fr", "he", "ja", "zh"}

# One bilingual Helsinki-NLP model per source language.
model_names = {
    "ar": "Helsinki-NLP/opus-mt-ar-en",
    "en": "Helsinki-NLP/opus-mt-en-fr",
    "fa": "Helsinki-NLP/opus-mt-tc-big-fa-itc",
    "fr": "Helsinki-NLP/opus-mt-fr-en",
    "he": "Helsinki-NLP/opus-mt-tc-big-he-en",
    "ja": "Helsinki-NLP/opus-mt-jap-en",
    "zh": "Helsinki-NLP/opus-mt-zh-en",
}

# Registry (lazy cache) of loaded (tokenizer, model) pairs, keyed by source language.
tokenizer_model_registry = {}

# Run on GPU when available, otherwise CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
def get_tokenizer_model_for_src_lang(src_lang: str) -> tuple[AutoTokenizer, AutoModelForSeq2SeqLM]:
    """Return (and cache) the (tokenizer, model) pair for a source language.

    Args:
        src_lang: ISO-639-1 code of the source language (case-insensitive);
            must be one of the keys of `model_names`.

    Returns:
        (tokenizer, model), with the model moved to the module-level `device`.

    Raises:
        ValueError: if no model is defined for `src_lang`.
            (ValueError subclasses Exception, so callers catching the old
            bare Exception keep working.)
    """
    src_lang = src_lang.lower()

    # Serve from the cache when this language was already loaded.
    if src_lang in tokenizer_model_registry:
        return tokenizer_model_registry[src_lang]

    model_name = model_names.get(src_lang)
    if not model_name:
        raise ValueError(f"No model defined for language: {src_lang}")

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

    # Use fp16 only on GPU: half precision on CPU is slow and several CPU ops
    # lack fp16 kernels, so the previous unconditional .half() risked runtime
    # errors when no GPU was present.
    if device == 'cuda' and model.config.torch_dtype != torch.float16:
        model = model.half()
    model = model.to(device)

    tokenizer_model_registry[src_lang] = (tokenizer, model)
    return tokenizer, model
# Max number of words for a given input text chunk.
# - The bilingual models usually accept up to 512 tokens
#   (max position encodings, as well as max generation length).
# - Keep chunks at a word count comfortably below that token threshold,
#   e.g. 200 words.
max_words_per_chunk = 200

#
# Multilingual translation model (MADLAD-400)
#
model_MADLAD_name = "google/madlad400-3b-mt"
#model_MADLAD_name = "google/madlad400-7b-mt-bt"

# NOTE(review): loaded eagerly at import time — this downloads/loads a
# 3B-parameter model in fp16, which can take significant time and memory.
tokenizer_multilingual = AutoTokenizer.from_pretrained(model_MADLAD_name, use_fast=True)
model_multilingual = AutoModelForSeq2SeqLM.from_pretrained(
    model_MADLAD_name, device_map="auto", torch_dtype=torch.float16)