Spaces:
Sleeping
Sleeping
File size: 3,461 Bytes
fe02c49 eec4fa3 fe02c49 706408b fe02c49 706408b 77364cc 706408b 77364cc 706408b 150301e 706408b fe02c49 efcd81a fe02c49 efcd81a fe02c49 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 |
"""
File: model_translation.py
Description:
Loads the models used for text translation.
Author: Didier Guillevic
Date: 2024-03-16
"""
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from transformers import M2M100Tokenizer, M2M100ForConditionalGeneration
class Singleton(type):
    """Metaclass that hands out a single shared instance per class.

    The first call to a class built with this metaclass constructs the
    instance, forwarding the call's arguments. Every later call returns that
    same object, and the arguments of those later calls are ignored.
    """

    _instances = {}

    def __call__(cls, *args, **kwargs):
        cache = cls._instances
        if cls in cache:
            # Already constructed: reuse it, ignoring the new arguments.
            return cache[cls]
        created = super().__call__(*args, **kwargs)
        cache[cls] = created
        return created
class ModelM2M100(metaclass=Singleton):
    """Shared handle on the ``facebook/m2m100_1.2B`` translation model.

    Because the class uses the ``Singleton`` metaclass, the tokenizer and
    model are loaded via ``from_pretrained`` only on the first instantiation;
    later instantiations return that same object. The model is requested in
    float16 with ``device_map="auto"``, which leaves device placement to
    transformers.
    """

    def __init__(self):
        checkpoint = "facebook/m2m100_1.2B"
        self._model_name = checkpoint
        self._tokenizer = M2M100Tokenizer.from_pretrained(checkpoint)
        load_options = dict(
            device_map="auto",
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
        )
        self._model = M2M100ForConditionalGeneration.from_pretrained(
            checkpoint, **load_options
        )

    @property
    def model_name(self):
        """Hub identifier of the loaded checkpoint."""
        return self._model_name

    @property
    def tokenizer(self):
        """The ``M2M100Tokenizer`` loaded for the checkpoint."""
        return self._tokenizer

    @property
    def model(self):
        """The ``M2M100ForConditionalGeneration`` model (float16)."""
        return self._model
class ModelMADLAD(metaclass=Singleton):
    """Shared handle on Google's MADLAD-400 3B translation model.

    The ``Singleton`` metaclass makes the (expensive) load happen once: the
    first instantiation loads a fast tokenizer and the seq2seq model for
    ``google/madlad400-3b-mt``, and later instantiations return that same
    object. The model is requested in float16 with ``device_map="auto"``,
    which leaves device placement to transformers.
    """

    def __init__(self):
        checkpoint = "google/madlad400-3b-mt"
        self._model_name = checkpoint
        self._tokenizer = AutoTokenizer.from_pretrained(checkpoint, use_fast=True)
        load_options = dict(
            device_map="auto",
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
        )
        self._model = AutoModelForSeq2SeqLM.from_pretrained(
            checkpoint, **load_options
        )

    @property
    def model_name(self):
        """Hub identifier of the loaded checkpoint."""
        return self._model_name

    @property
    def tokenizer(self):
        """The (fast) tokenizer loaded for the checkpoint."""
        return self._tokenizer

    @property
    def model(self):
        """The seq2seq translation model (float16)."""
        return self._model
# Bilingual (single language-pair) Helsinki-NLP OPUS-MT checkpoints.
# NOTE(review): "ja" appears in src_langs but has no entry in model_names, so
# get_tokenizer_model_for_src_lang("ja") raises — confirm whether callers rely
# on that (e.g. to fall back to a multilingual model) before adding one.
src_langs = {"ar", "en", "fa", "fr", "he", "ja", "zh"}

# Source-language code -> checkpoint name. Most checkpoints target English;
# "en" maps to an English->French model, and the "fa" checkpoint ends in
# "-itc" rather than "-en" (its target side presumably differs — confirm).
model_names = {
    "ar": "Helsinki-NLP/opus-mt-ar-en",
    "en": "Helsinki-NLP/opus-mt-en-fr",
    "fa": "Helsinki-NLP/opus-mt-tc-big-fa-itc",
    "fr": "Helsinki-NLP/opus-mt-fr-en",
    "he": "Helsinki-NLP/opus-mt-tc-big-he-en",
    "zh": "Helsinki-NLP/opus-mt-zh-en",
}

# Cache of already-loaded bilingual models: source-language code ->
# (tokenizer, model). Filled lazily by get_tokenizer_model_for_src_lang.
tokenizer_model_registry: dict = {}

# Device the bilingual models are moved to after loading.
device = "cpu"
def get_tokenizer_model_for_src_lang(
    src_lang: str,
) -> "tuple[AutoTokenizer, AutoModelForSeq2SeqLM]":
    """Return the cached (tokenizer, model) pair for a source language.

    On the first request for a language, the bilingual checkpoint named in
    ``model_names`` is loaded, converted to float16, moved to ``device`` and
    stored in ``tokenizer_model_registry``. Later requests return the cached
    pair without reloading.

    Args:
        src_lang: Source-language code, e.g. ``"fr"``. Matched
            case-insensitively, ignoring surrounding whitespace.

    Returns:
        A ``(tokenizer, model)`` tuple.

    Raises:
        ValueError: If ``model_names`` has no checkpoint for ``src_lang``.
            ValueError subclasses Exception, so existing
            ``except Exception`` handlers still catch it.
    """
    src_lang = src_lang.strip().lower()

    # Already loaded? Registry values are tuples, never None.
    cached = tokenizer_model_registry.get(src_lang)
    if cached is not None:
        return cached

    model_name = model_names.get(src_lang)
    if not model_name:
        raise ValueError(f"No model defined for language: {src_lang}")

    # We will leave the models on the CPU (for now).
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

    # from_pretrained() without torch_dtype loads weights in float32 whatever
    # config.torch_dtype declares, so test the dtype the weights actually
    # have rather than the config value.
    if model.dtype != torch.float16:
        model = model.half()

    # NOTE(review): float16 inference on CPU is slow, and on older PyTorch
    # releases some ops lack Half CPU kernels — confirm this works with the
    # deployed torch version.
    model.to(device)

    tokenizer_model_registry[src_lang] = (tokenizer, model)
    return tokenizer, model
# Word budget for one chunk of input text.
# The translation models usually cap input at 512 tokens (both the max
# position encodings and the max length). Words often split into several
# sub-word tokens, so the word budget is kept well below that threshold.
max_words_per_chunk: int = 200
|