---
license: unknown
---

For inference:
```python
import torch
import torch.nn as nn
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast


class TextRefinementModel(nn.Module):
    def __init__(self, model_name='tirthadagr8/custom-mbart-large-50', max_length=64):
        super().__init__()
        self.tokenizer = MBart50TokenizerFast.from_pretrained(model_name)
        self.mbart = MBartForConditionalGeneration.from_pretrained(model_name)
        # Use the constructor argument rather than a hard-coded 64
        self.mbart.config.max_length = max_length
        self.max_length = max_length

        # Set the source language code before tokenizing:
        # self.tokenizer.src_lang = 'ja_XX'  # for Japanese
        # self.tokenizer.src_lang = 'zh_CN'  # for Chinese
    def forward(self, input_texts):
        # Tokenize the noisy text inputs
        inputs = self.tokenizer(input_texts, return_tensors='pt', padding=True,
                                truncation=True, max_length=self.max_length)

        # With no labels or decoder_input_ids, mBART builds the decoder inputs
        # by shifting input_ids, so this scores a reconstruction of the input
        output_logits = self.mbart(input_ids=inputs['input_ids'],
                                   attention_mask=inputs['attention_mask']).logits

        return output_logits
    def generate_corrected_text(self, input_texts, temperature=0.7):
        # Tokenize the input noisy text
        inputs = self.tokenizer(input_texts, return_tensors='pt', padding=True,
                                truncation=True, max_length=self.max_length)

        # Generate corrected text with mBART's generate function;
        # do_sample=True is required for temperature to take effect
        mbart_outputs = self.mbart.generate(inputs['input_ids'],
                                            attention_mask=inputs['attention_mask'],
                                            max_length=self.max_length,
                                            do_sample=True,
                                            temperature=temperature,
                                            num_return_sequences=1)

        # Decode generated token ids back to text
        corrected_texts = [self.tokenizer.decode(g, skip_special_tokens=True)
                           for g in mbart_outputs]
        return corrected_texts
# Example usage
model = TextRefinementModel()

noisy_text = ["これは間違ったテキストの例です。", "这是错误的文本示例。"]  # Japanese and Chinese examples
corrected_text = model.generate_corrected_text(noisy_text)

print(f"Corrected Text: {corrected_text}")
```
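Because mBART-50 is multilingual, inference is usually more reliable when the tokenizer's source language is set and the matching language token is forced at the start of decoding. A minimal sketch for Japanese input (for a same-language refinement task the source and target codes are identical; `ja_XX` comes from mBART-50's language list):

```python
model.tokenizer.src_lang = 'ja_XX'  # tokenize the input as Japanese

inputs = model.tokenizer(["これは間違ったテキストの例です。"], return_tensors='pt',
                         padding=True, truncation=True, max_length=model.max_length)
outputs = model.mbart.generate(
    inputs['input_ids'],
    attention_mask=inputs['attention_mask'],
    max_length=model.max_length,
    # Force the decoder to start with the Japanese language token so the
    # output stays in the input language
    forced_bos_token_id=model.tokenizer.lang_code_to_id['ja_XX'],
)
print(model.tokenizer.batch_decode(outputs, skip_special_tokens=True))
```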
For training:
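The loop below reads from a `train_data` variable that the snippet never defines. A minimal sketch of the expected shape, assuming each entry is a `(noisy_text, correct_text)` string pair (the pairs here are illustrative placeholders, not real training data):

```python
# Hypothetical placeholder corpus: replace with real (noisy, clean) pairs
train_data = [
    ("これは間違ったテキストの例です。", "これは正しいテキストの例です。"),
    ("这是错误的文本示例。", "这是正确的文本示例。"),
]
```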
```python
import numpy as np
import torch
from torch.optim import AdamW  # transformers.AdamW is deprecated; use torch's
from torch.utils.data import DataLoader
from tqdm import tqdm

# Initialize the mBART model and optimizer (assumes a CUDA-capable GPU)
model = TextRefinementModel().cuda()
optimizer = AdamW(model.parameters(), lr=5e-5)

batch_size = 16
# Create a custom dataset class
class TextCorrectionDataset(torch.utils.data.Dataset):
    def __init__(self, data, tokenizer, max_length=64):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        noisy_text, correct_text = self.data[idx]
        inputs = self.tokenizer(noisy_text, return_tensors='pt', padding='max_length',
                                truncation=True, max_length=self.max_length)
        labels = self.tokenizer(correct_text, return_tensors='pt', padding='max_length',
                                truncation=True, max_length=self.max_length)

        input_ids = inputs['input_ids'].squeeze(0)  # drop the batch dimension
        labels = labels['input_ids'].squeeze(0)     # same for labels
        # Replace padding in the labels with -100 so the loss ignores it
        labels[labels == self.tokenizer.pad_token_id] = -100
        return input_ids, labels
# Create DataLoader with batching
train_dataset = TextCorrectionDataset(train_data, model.tokenizer)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
# Define training loop with batches
def train_epoch(model, train_loader, optimizer):
    model.train()
    total_loss = []
    step_iter = 0
    for input_ids, labels in tqdm(train_loader):
        # Move tensors to the model's device
        input_ids = input_ids.to(model.mbart.device)
        labels = labels.to(model.mbart.device)

        # Forward pass; passing labels makes mBART compute the LM loss itself
        outputs = model.mbart(input_ids=input_ids, labels=labels)
        loss = outputs.loss

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss.append(loss.item())

        # Log the running mean loss every 100 steps
        if step_iter % 100 == 0:
            print('Loss:', np.mean(total_loss))
        step_iter += 1

    return np.mean(total_loss)
# Example training loop
for epoch in range(5):  # Train for 5 epochs (or as needed)
    loss = train_epoch(model, train_loader, optimizer)
    print(f"Epoch {epoch+1}, Loss: {loss:.4f}")
```
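Once training finishes, the fine-tuned weights and tokenizer can be persisted with the standard Hugging Face `save_pretrained` calls (a sketch; the output directory name is an assumption):

```python
# Hypothetical output directory; any writable path works
save_dir = './mbart-text-refinement-finetuned'

model.mbart.save_pretrained(save_dir)
model.tokenizer.save_pretrained(save_dir)

# Reload later by pointing the wrapper at the saved directory
reloaded = TextRefinementModel(model_name=save_dir)
```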