Spaces:
Sleeping
Sleeping
File size: 2,189 Bytes
fe02c49 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 |
"""
File: model_translation.py
Description:
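Loading models for text translation: per-source-language bilingual models (Helsinki-NLP OPUS-MT, e.g. EN->FR, FR->EN) and a multilingual model (google/madlad400-3b-mt).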
Author: Didier Guillevic
Date: 2024-03-16
"""
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
# Bilingual translation models (Hugging Face model ids), keyed by the
# lowercase source-language code they translate from.
model_names = {
    "ar": "Helsinki-NLP/opus-mt-ar-en",
    "en": "Helsinki-NLP/opus-mt-en-fr",
    # NOTE(review): "itc" in this id suggests Italic languages as the target
    # (not English), and multi-target OPUS-MT models usually expect a
    # target-language token such as ">>fra<<" in the input — confirm intent.
    "fa": "Helsinki-NLP/opus-mt-tc-big-fa-itc",
    "fr": "Helsinki-NLP/opus-mt-fr-en",
    "he": "Helsinki-NLP/opus-mt-tc-big-he-en",
    # NOTE(review): "jap" is not an ISO code for Japanese ("ja" / "jpn");
    # confirm whether "Helsinki-NLP/opus-mt-ja-en" was intended here.
    "ja": "Helsinki-NLP/opus-mt-jap-en",
    "zh": "Helsinki-NLP/opus-mt-zh-en",
}
# Supported source languages. Derived from model_names so the two can never
# drift apart when a language is added or removed.
src_langs = set(model_names)
# Registry for all loaded bilingual models: src_lang -> (tokenizer, model).
tokenizer_model_registry = {}
# Use the GPU when one is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
def get_tokenizer_model_for_src_lang(
    src_lang: str,
) -> "tuple[AutoTokenizer, AutoModelForSeq2SeqLM]":
    """
    Return the (tokenizer, model) pair for a given source language.

    The pair is loaded on first use and cached in ``tokenizer_model_registry``,
    so later calls for the same language return the same objects.

    Args:
        src_lang: Source-language code, e.g. "en" or "FR". Matching is
            case-insensitive and ignores surrounding whitespace.

    Returns:
        A ``(tokenizer, model)`` tuple. The model has been moved to ``device``
        and, when that device is CUDA, cast to float16.

    Raises:
        ValueError: if ``model_names`` has no entry for ``src_lang``. This is a
            subclass of ``Exception``, so existing broad handlers still catch it.
    """
    key = src_lang.strip().lower()

    # Already loaded? The registry only ever stores tuples, never None.
    cached = tokenizer_model_registry.get(key)
    if cached is not None:
        return cached

    model_name = model_names.get(key)
    if not model_name:
        raise ValueError(f"No model defined for language: {key}")

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    # Cast to half precision only on the GPU. On the CPU, float16 inference is
    # slow and some ops are unimplemented for Half in older PyTorch releases,
    # so keep the dtype the model was loaded with. Checking the target device,
    # rather than config.torch_dtype (checkpoint metadata), decides from what
    # actually matters at runtime.
    if device == 'cuda':
        model = model.half()
    model = model.to(device)

    tokenizer_model_registry[key] = (tokenizer, model)
    return tokenizer, model
# Word budget for a single chunk of input text sent to a translation model.
# These models usually accept at most 512 tokens (their max position
# encodings as well as their max length). A word can be split into several
# tokens, so the budget is kept comfortably below that limit.
max_words_per_chunk: int = 200
#
# Multilingual translation model (MADLAD-400), loaded once when this module
# is imported. A larger alternative checkpoint is "google/madlad400-7b-mt-bt".
#
model_MADLAD_name = "google/madlad400-3b-mt"
tokenizer_multilingual = AutoTokenizer.from_pretrained(
    model_MADLAD_name,
    use_fast=True,
)
# device_map="auto" delegates weight placement across the available devices
# to transformers. NOTE(review): float16 is requested even when no GPU is
# present — confirm this is acceptable for CPU-only deployments.
model_multilingual = AutoModelForSeq2SeqLM.from_pretrained(
    model_MADLAD_name,
    torch_dtype=torch.float16,
    device_map="auto",
)
|