nitikdias's picture
Upload 114 files
74ee63f verified
import torch
from transformers import AutoModelForSeq2SeqLM, BitsAndBytesConfig, AutoTokenizer
from IndicTransToolkit import IndicProcessor
# Constants
BATCH_SIZE = 4
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
quantization = None
def initialize_model_and_tokenizer(ckpt_dir, quantization):
"""Initialize the model and tokenizer with optional quantization."""
if quantization == "4-bit":
qconfig = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=True,
bnb_4bit_compute_dtype=torch.bfloat16,
)
elif quantization == "8-bit":
qconfig = BitsAndBytesConfig(
load_in_8bit=True,
bnb_8bit_use_double_quant=True,
bnb_8bit_compute_dtype=torch.bfloat16,
)
else:
qconfig = None
tokenizer = AutoTokenizer.from_pretrained(ckpt_dir, trust_remote_code=True)
model = AutoModelForSeq2SeqLM.from_pretrained(
ckpt_dir,
trust_remote_code=True,
low_cpu_mem_usage=True,
quantization_config=qconfig,
)
if qconfig is None:
model = model.to(DEVICE)
if DEVICE == "cuda":
model.half()
model.eval()
return tokenizer, model
def batch_translate(input_sentences, src_lang, tgt_lang, model, tokenizer, ip):
"""Batch translate sentences from src_lang to tgt_lang."""
translations = []
for i in range(0, len(input_sentences), BATCH_SIZE):
batch = input_sentences[i : i + BATCH_SIZE]
# Preprocess the batch and extract entity mappings
batch = ip.preprocess_batch(batch, src_lang=src_lang, tgt_lang=tgt_lang)
# Tokenize the batch and generate input encodings
inputs = tokenizer(
batch,
truncation=True,
padding="longest",
return_tensors="pt",
return_attention_mask=True,
).to(DEVICE)
# Generate translations using the model
with torch.no_grad():
generated_tokens = model.generate(
**inputs,
use_cache=True,
min_length=0,
max_length=256,
num_beams=5,
num_return_sequences=1,
)
# Decode the generated tokens into text
with tokenizer.as_target_tokenizer():
generated_tokens = tokenizer.batch_decode(
generated_tokens.detach().cpu().tolist(),
skip_special_tokens=True,
clean_up_tokenization_spaces=True,
)
# Postprocess the translations, including entity replacement
translations += ip.postprocess_batch(generated_tokens, lang=tgt_lang)
del inputs
torch.cuda.empty_cache()
return translations
# Initialize the model and processor
en_indic_ckpt_dir = "ai4bharat/indictrans2-indic-en-1B"
en_indic_tokenizer, en_indic_model = initialize_model_and_tokenizer(en_indic_ckpt_dir, quantization)
ip = IndicProcessor(inference=True)
# Sample sentences
en_sents = [
"""ഹലോ ഫ്രണ്ട്സ് കോളേജ് സ്കൂളിൻറെ മറ്റൊരു അധ്യായത്തിലേക്ക് ഏവർക്കും സ്വാഗതം ഇന്ന് ഞാൻ വന്നിരിക്കുന്നത് ചെറിയ കുട്ടികൾക്കായുള്ള ഒരു മലയാളം പ്രസംഗവും ആയിട്ടാണ് പ്രസംഗ വിഷയം ഇന്ത്യ എൻറെ രാജ്യം ആയിരക്കണക്കിന് വർഷങ്ങളുടെ പാരമ്പര്യം പേറുന്ന മഹത്തായ രാജ്യമാണ് ഇന്ത്യ 1947 ൽ ബ്രിട്ടീഷുകാരിൽ നിന്നും സ്വാതന്ത്ര്യം നേടിയ നമ്മുടെ ഭാരതം അനേകം നാട്ടുരാജ്യങ്ങൾ ചേർന്ന് ഏറ്റവും വലിയ ജനാധിപത്യ രാജ്യമായി ആശയുടെ അടിസ്ഥാനത്തിൽ നല്ല ഭരണത്തിന് സഹായകമാകും വിധം സംസ്ഥാനങ്ങൾ രൂപം കൊണ്ടും എന്ന് 28 സംസ്ഥാനങ്ങൾ ആണ് ഇന്ത്യയിൽ ഉള്ളത് നാനാത്വത്തിലെ ഏകത്വം എന്ന ചിന്ത വിവിധ ഭാഷകളും ജാതികളും മതങ്ങളും ആചാരങ്ങളും ജീവിതരീതികളും ഉള്ള ഒരു വലിയ ജനതയെ ഒറ്റക്കെട്ടായി നിർത്തുന്നു അതാണ് ഭാരതത്തിൻറെ വിജയം നേടിയ ലോകമേ തറവാട് എന്നതാണ് ഭാരത സംസ്കാരം അതുകൊണ്ട് തന്നെ ഇന്ത്യക്കാരെ മാത്രമല്ല ലോകം മുഴുവനും ഉള്ള എല്ലാവരെയും ഭാരതം സന്തോഷത്തോടെ ഉൾക്കൊള്ളുകയും സ്നേഹിക്കുകയും ചെയ്യുന്ന പ്രസിഡണ്ടും പ്രധാനമന്ത്രിയും മന്ത്രിമാരും ചേർന്ന് നമ്മുടെ രാജ്യം ഭരിക്കുന്നു മുഖ്യമന്ത്രിയും മന്ത്രിമാരും ചേർന്ന് സംസ്ഥാനങ്ങളെയും പരിപാലിക്കുന്നു എൻറെ ഇന്ത്യ അഭിമാനമാണ് സംസ്കാരങ്ങൾ ചേർന്ന് മനോഹരിയായി പുഞ്ചിരിക്കുന്ന എൻറെ അമ്മ ഭാരതമെന്നു കേട്ടാൽ തിളക്കണം ചോര നമുക്ക് ഞരമ്പുകളിൽ"""
]
# Translation
src_lang, tgt_lang = "mal_Mlym", "eng_Latn"
hi_translations = batch_translate(en_sents, src_lang, tgt_lang, en_indic_model, en_indic_tokenizer, ip)
# Print translations
print(f"\n{src_lang} - {tgt_lang}")
for input_sentence, translation in zip(en_sents, hi_translations):
print(f"{src_lang}: {input_sentence}")
print(f"{tgt_lang}: {translation}")
# Free GPU memory
del en_indic_tokenizer, en_indic_model
torch.cuda.empty_cache()