|
import torch
|
|
from transformers import AutoModelForSeq2SeqLM, BitsAndBytesConfig, AutoTokenizer
|
|
from IndicTransToolkit import IndicProcessor
|
|
|
|
|
|
BATCH_SIZE = 4
|
|
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
|
quantization = None
|
|
|
|
def initialize_model_and_tokenizer(ckpt_dir, quantization):
|
|
"""Initialize the model and tokenizer with optional quantization."""
|
|
if quantization == "4-bit":
|
|
qconfig = BitsAndBytesConfig(
|
|
load_in_4bit=True,
|
|
bnb_4bit_use_double_quant=True,
|
|
bnb_4bit_compute_dtype=torch.bfloat16,
|
|
)
|
|
elif quantization == "8-bit":
|
|
qconfig = BitsAndBytesConfig(
|
|
load_in_8bit=True,
|
|
bnb_8bit_use_double_quant=True,
|
|
bnb_8bit_compute_dtype=torch.bfloat16,
|
|
)
|
|
else:
|
|
qconfig = None
|
|
|
|
tokenizer = AutoTokenizer.from_pretrained(ckpt_dir, trust_remote_code=True)
|
|
model = AutoModelForSeq2SeqLM.from_pretrained(
|
|
ckpt_dir,
|
|
trust_remote_code=True,
|
|
low_cpu_mem_usage=True,
|
|
quantization_config=qconfig,
|
|
)
|
|
|
|
if qconfig is None:
|
|
model = model.to(DEVICE)
|
|
if DEVICE == "cuda":
|
|
model.half()
|
|
|
|
model.eval()
|
|
return tokenizer, model
|
|
|
|
|
|
def batch_translate(input_sentences, src_lang, tgt_lang, model, tokenizer, ip):
|
|
"""Batch translate sentences from src_lang to tgt_lang."""
|
|
translations = []
|
|
|
|
for i in range(0, len(input_sentences), BATCH_SIZE):
|
|
batch = input_sentences[i : i + BATCH_SIZE]
|
|
|
|
|
|
batch = ip.preprocess_batch(batch, src_lang=src_lang, tgt_lang=tgt_lang)
|
|
|
|
|
|
inputs = tokenizer(
|
|
batch,
|
|
truncation=True,
|
|
padding="longest",
|
|
return_tensors="pt",
|
|
return_attention_mask=True,
|
|
).to(DEVICE)
|
|
|
|
|
|
with torch.no_grad():
|
|
generated_tokens = model.generate(
|
|
**inputs,
|
|
use_cache=True,
|
|
min_length=0,
|
|
max_length=256,
|
|
num_beams=5,
|
|
num_return_sequences=1,
|
|
)
|
|
|
|
|
|
with tokenizer.as_target_tokenizer():
|
|
generated_tokens = tokenizer.batch_decode(
|
|
generated_tokens.detach().cpu().tolist(),
|
|
skip_special_tokens=True,
|
|
clean_up_tokenization_spaces=True,
|
|
)
|
|
|
|
|
|
translations += ip.postprocess_batch(generated_tokens, lang=tgt_lang)
|
|
|
|
del inputs
|
|
torch.cuda.empty_cache()
|
|
|
|
return translations
|
|
|
|
|
|
|
|
en_indic_ckpt_dir = "ai4bharat/indictrans2-indic-en-1B"
|
|
en_indic_tokenizer, en_indic_model = initialize_model_and_tokenizer(en_indic_ckpt_dir, quantization)
|
|
ip = IndicProcessor(inference=True)
|
|
|
|
|
|
en_sents = [
|
|
"""ഹലോ ഫ്രണ്ട്സ് കോളേജ് സ്കൂളിൻറെ മറ്റൊരു അധ്യായത്തിലേക്ക് ഏവർക്കും സ്വാഗതം ഇന്ന് ഞാൻ വന്നിരിക്കുന്നത് ചെറിയ കുട്ടികൾക്കായുള്ള ഒരു മലയാളം പ്രസംഗവും ആയിട്ടാണ് പ്രസംഗ വിഷയം ഇന്ത്യ എൻറെ രാജ്യം ആയിരക്കണക്കിന് വർഷങ്ങളുടെ പാരമ്പര്യം പേറുന്ന മഹത്തായ രാജ്യമാണ് ഇന്ത്യ 1947 ൽ ബ്രിട്ടീഷുകാരിൽ നിന്നും സ്വാതന്ത്ര്യം നേടിയ നമ്മുടെ ഭാരതം അനേകം നാട്ടുരാജ്യങ്ങൾ ചേർന്ന് ഏറ്റവും വലിയ ജനാധിപത്യ രാജ്യമായി ആശയുടെ അടിസ്ഥാനത്തിൽ നല്ല ഭരണത്തിന് സഹായകമാകും വിധം സംസ്ഥാനങ്ങൾ രൂപം കൊണ്ടും എന്ന് 28 സംസ്ഥാനങ്ങൾ ആണ് ഇന്ത്യയിൽ ഉള്ളത് നാനാത്വത്തിലെ ഏകത്വം എന്ന ചിന്ത വിവിധ ഭാഷകളും ജാതികളും മതങ്ങളും ആചാരങ്ങളും ജീവിതരീതികളും ഉള്ള ഒരു വലിയ ജനതയെ ഒറ്റക്കെട്ടായി നിർത്തുന്നു അതാണ് ഭാരതത്തിൻറെ വിജയം നേടിയ ലോകമേ തറവാട് എന്നതാണ് ഭാരത സംസ്കാരം അതുകൊണ്ട് തന്നെ ഇന്ത്യക്കാരെ മാത്രമല്ല ലോകം മുഴുവനും ഉള്ള എല്ലാവരെയും ഭാരതം സന്തോഷത്തോടെ ഉൾക്കൊള്ളുകയും സ്നേഹിക്കുകയും ചെയ്യുന്ന പ്രസിഡണ്ടും പ്രധാനമന്ത്രിയും മന്ത്രിമാരും ചേർന്ന് നമ്മുടെ രാജ്യം ഭരിക്കുന്നു മുഖ്യമന്ത്രിയും മന്ത്രിമാരും ചേർന്ന് സംസ്ഥാനങ്ങളെയും പരിപാലിക്കുന്നു എൻറെ ഇന്ത്യ അഭിമാനമാണ് സംസ്കാരങ്ങൾ ചേർന്ന് മനോഹരിയായി പുഞ്ചിരിക്കുന്ന എൻറെ അമ്മ ഭാരതമെന്നു കേട്ടാൽ തിളക്കണം ചോര നമുക്ക് ഞരമ്പുകളിൽ"""
|
|
]
|
|
|
|
|
|
src_lang, tgt_lang = "mal_Mlym", "eng_Latn"
|
|
hi_translations = batch_translate(en_sents, src_lang, tgt_lang, en_indic_model, en_indic_tokenizer, ip)
|
|
|
|
|
|
print(f"\n{src_lang} - {tgt_lang}")
|
|
for input_sentence, translation in zip(en_sents, hi_translations):
|
|
print(f"{src_lang}: {input_sentence}")
|
|
print(f"{tgt_lang}: {translation}")
|
|
|
|
|
|
del en_indic_tokenizer, en_indic_model
|
|
torch.cuda.empty_cache()
|
|
|