---
license: unknown
---

```python
import torch
import torch.nn as nn
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast


class TextRefinementModel(nn.Module):
    def __init__(self, model_name='tirthadagr8/custom-mbart-large-50', max_length=64):
        super(TextRefinementModel, self).__init__()
        self.tokenizer = MBart50TokenizerFast.from_pretrained(model_name)
        self.mbart = MBartForConditionalGeneration.from_pretrained(model_name)
        self.mbart.config.max_length = max_length
        self.max_length = max_length

        # Set the source language code for Japanese (ja_XX) or Chinese (zh_CN)
        # self.tokenizer.src_lang = 'ja_XX'  # For Japanese
        # self.tokenizer.src_lang = 'zh_CN'  # Uncomment for Chinese

    def forward(self, input_texts):
        # Tokenize the noisy text inputs
        encoding = self.tokenizer(input_texts, return_tensors='pt', padding=True,
                                  truncation=True, max_length=self.max_length).to(self.mbart.device)

        # mBART builds decoder inputs by shifting the encoder inputs and returns output logits
        output_logits = self.mbart(input_ids=encoding['input_ids'],
                                   attention_mask=encoding['attention_mask']).logits
        return output_logits

    def generate_corrected_text(self, input_texts, temperature=0.7):
        # Tokenize the input noisy text
        encoding = self.tokenizer(input_texts, return_tensors='pt', padding=True,
                                  truncation=True, max_length=self.max_length).to(self.mbart.device)

        # Generate corrected text; do_sample=True is needed for temperature to take effect
        mbart_outputs = self.mbart.generate(input_ids=encoding['input_ids'],
                                            attention_mask=encoding['attention_mask'],
                                            max_length=self.max_length,
                                            do_sample=True,
                                            temperature=temperature,
                                            num_return_sequences=1)

        # Decode generated text
        corrected_texts = [self.tokenizer.decode(g, skip_special_tokens=True) for g in mbart_outputs]
        return corrected_texts


# Example usage
model = TextRefinementModel()
noisy_text = ["これは間違ったテキストの例です。", "这是错误的文本示例。"]  # Japanese and Chinese examples
corrected_text = model.generate_corrected_text(noisy_text)
print(f"Corrected Text: {corrected_text}")
```
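The training code below reads from a `train_data` variable that is not defined in this card; it is assumed to be a list of `(noisy_text, correct_text)` string pairs. A minimal sketch of that format, with placeholder sentences, might look like:

```python
# Assumed format for train_data: a list of (noisy_text, correct_text) string pairs.
# The sentences below are placeholders, not real training data.
train_data = [
    ("これは間違ったテキストの例です。", "これは正しいテキストの例です。"),  # "wrong text example" -> "correct text example"
    ("这是错误的文本示例。", "这是正确的文本示例。"),
]
```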
For training:

```python
import numpy as np
import torch
from torch.optim import AdamW
from torch.utils.data import DataLoader
from tqdm import tqdm

# Initialize the mBART model and optimizer
model = TextRefinementModel().cuda()
optimizer = AdamW(model.parameters(), lr=5e-5)
batch_size = 16


# Create a custom dataset class
class TextCorrectionDataset(torch.utils.data.Dataset):
    def __init__(self, data, tokenizer, max_length=64):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        noisy_text, correct_text = self.data[idx]
        inputs = self.tokenizer(noisy_text, return_tensors='pt', padding='max_length',
                                truncation=True, max_length=self.max_length)
        targets = self.tokenizer(correct_text, return_tensors='pt', padding='max_length',
                                 truncation=True, max_length=self.max_length)

        # Remove the extra batch dimension added by return_tensors='pt'
        input_ids = inputs['input_ids'].squeeze(0)
        attention_mask = inputs['attention_mask'].squeeze(0)
        labels = targets['input_ids'].squeeze(0)

        # Ignore padding positions when computing the loss
        labels[labels == self.tokenizer.pad_token_id] = -100
        return input_ids, attention_mask, labels


# Create DataLoader with batching
train_dataset = TextCorrectionDataset(train_data, model.tokenizer)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)


# Define training loop with batches
def train_epoch(model, train_loader, optimizer):
    model.train()
    total_loss = []
    step_iter = 0
    for input_ids, attention_mask, labels in tqdm(train_loader):
        # Move tensors to the model's device
        input_ids = input_ids.to(model.mbart.device)
        attention_mask = attention_mask.to(model.mbart.device)
        labels = labels.to(model.mbart.device)

        # Forward pass; mBART computes the cross-entropy loss when labels are given
        outputs = model.mbart(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss.append(loss.item())
        if step_iter % 100 == 0:
            print('Loss:', np.mean(total_loss))
        step_iter += 1
    return np.mean(total_loss)


# Example training loop
for epoch in range(5):  # Train for 5 epochs (or as needed)
    loss = train_epoch(model, train_loader, optimizer)
    print(f"Epoch {epoch+1}, Loss: {loss:.4f}")
```
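After fine-tuning, the underlying mBART weights and tokenizer can be saved with the standard Hugging Face `save_pretrained` methods; the output directory name below is only an example:

```python
# Save the fine-tuned weights and tokenizer (directory name is arbitrary)
output_dir = "custom-mbart-large-50-finetuned"
model.mbart.save_pretrained(output_dir)
model.tokenizer.save_pretrained(output_dir)

# The saved checkpoint can later be reloaded by path:
# reloaded = TextRefinementModel(model_name=output_dir)
```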