File size: 2,967 Bytes
3bd5293 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 |
import os
import pandas as pd
import torch
import argparse
import shutil
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, DataCollatorWithPadding
from torch.utils.data import Dataset, DataLoader
# Command line: `-f/--filename` names the feather file to translate and the
# optional positional `data_folder` (default "data_multi") is where it lives.
parser = argparse.ArgumentParser()
# The filename is joined onto the data folder below, so a missing value can
# never work; let argparse report it instead of failing later with a TypeError.
parser.add_argument("-f", "--filename", type=str, required=True)
parser.add_argument(
    "data_folder",
    nargs="?",
    type=str,
    default="data_multi",
)
args = parser.parse_args()

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# All results are written into this folder.
os.makedirs("data_translated", exist_ok=True)

if args.filename == "015_ccs_synthetic_en.feather":
    # No need to translate English -> English
    shutil.copy2(os.path.join(args.data_folder, args.filename), "data_translated")
    # Exit through SystemExit rather than os._exit so that the interpreter
    # still runs its normal cleanup and flushes buffered output.
    raise SystemExit(0)

df = pd.read_feather(os.path.join(args.data_folder, args.filename))
# Strip the site prefix so that only the part after "https://huggingface.co/"
# remains. regex=False makes this a literal replacement on every pandas
# version (older releases treat the pattern as a regular expression).
df["opus_mt_url"] = df["opus_mt_url"].str.replace(
    "https://huggingface.co/", "", regex=False
)
print(f"Starting translation of English to {df['multi_language_name'][0]}")
class CaptionDataset(Dataset):
    """Dataset that tokenizes the ``caption`` column of a frame row by row.

    Args:
        df: Frame with a ``caption`` column holding the text to tokenize.
        tokenizer_name: Identifier handed to ``AutoTokenizer.from_pretrained``.
    """

    def __init__(self, df, tokenizer_name):
        self.df = df
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

    def __len__(self):
        """Return the number of rows in the wrapped frame."""
        return len(self.df)

    def __getitem__(self, index):
        """Tokenize the caption at position ``index``.

        Returns:
            Whatever the tokenizer returns when called with
            ``return_tensors="pt"`` on that single caption.
        """
        # Read from the frame given to the constructor (not a module-level
        # global) and by position, so the dataset works with any frame and
        # any index, matching the 0..len-1 positions a sampler produces.
        sentence1 = self.df["caption"].iloc[index]
        tokens = self.tokenizer(sentence1, return_tensors="pt")
        return tokens
# The tokenizer and the translation model both come from the checkpoint named
# in the first row of the "opus_mt_url" column.
checkpoint_name = df["opus_mt_url"][0]
tokenizer = AutoTokenizer.from_pretrained(checkpoint_name)
model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint_name)

# Move the weights to the selected device and switch to evaluation mode.
model = model.to(device).eval()
def custom_collate_fn(data, pad_token_id=0):
    """
    Data collator with padding.

    Args:
        data: List of samples. Each sample maps "input_ids" and
            "attention_mask" to a tensor whose first row is the sequence
            (only row 0 of each is used).
        pad_token_id: Value written into the "input_ids" positions that lie
            past the end of a sequence. Defaults to 0.

    Returns:
        Dict with "input_ids" and "attention_mask", each stacked batch-first
        and right-padded to the longest sequence in the batch. The attention
        mask is always padded with 0.
    """
    tokens = [sample["input_ids"][0] for sample in data]
    attention_masks = [sample["attention_mask"][0] for sample in data]

    # Padding positions get mask value 0 regardless of which id fills them.
    attention_masks = torch.nn.utils.rnn.pad_sequence(
        attention_masks, batch_first=True, padding_value=0
    )
    padded_tokens = torch.nn.utils.rnn.pad_sequence(
        tokens, batch_first=True, padding_value=pad_token_id
    )

    batch = {"input_ids": padded_tokens, "attention_mask": attention_masks}
    return batch
needs_language_prefix = df["multi_target"][0] == 1
if needs_language_prefix:
    # A multilingual model picks its output language from a ">>CODE<<" marker
    # at the start of the input, so prepend each row's target code.
    # NOTE(review): this overwrites "caption", so the prefixed text is also
    # what ends up in the saved file.
    df["caption"] = ">>" + df["target_code"] + "<<" + df["caption"]

test_data = CaptionDataset(df, df["opus_mt_url"][0])
# shuffle=False keeps the batches in row order, which the column assignment
# after the loop relies on.
test_dataloader = DataLoader(
    test_data,
    batch_size=50,
    shuffle=False,
    num_workers=4,
    collate_fn=custom_collate_fn,
)

decoded_tokens = []
with torch.no_grad():
    for batch in tqdm(test_dataloader):
        inputs_on_device = {name: tensor.to(device) for name, tensor in batch.items()}
        generated = model.generate(**inputs_on_device)
        # Decode on the CPU and collect the translated strings in order.
        translations = tokenizer.batch_decode(
            generated.to("cpu"), skip_special_tokens=True
        )
        decoded_tokens.extend(translations)

# Store the translations next to the source captions and write the frame out
# under the same filename in the output folder.
df["caption_multi"] = decoded_tokens
df.to_feather(os.path.join("data_translated", args.filename))
print(f"Finished translating English to {df['multi_language_name'][0]}")
|