tirthadagr8 committed
Commit a470f10 · verified · 1 Parent(s): 8a10b98

Update README.md

Files changed (1)
  1. README.md +115 -3

README.md CHANGED
@@ -1,3 +1,115 @@
- ---
- license: unknown
- ---
```python
import torch
import torch.nn as nn
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast


class TextRefinementModel(nn.Module):
    def __init__(self, model_name='tirthadagr8/custom-mbart-large-50', max_length=64):
        super(TextRefinementModel, self).__init__()
        self.tokenizer = MBart50TokenizerFast.from_pretrained(model_name)
        self.mbart = MBartForConditionalGeneration.from_pretrained(model_name)
        self.mbart.config.max_length = max_length
        self.max_length = max_length

        # Set the source language code for Japanese (ja_XX) or Chinese (zh_CN)
        # self.tokenizer.src_lang = 'ja_XX'  # For Japanese
        # self.tokenizer.src_lang = 'zh_CN'  # Uncomment for Chinese

    def forward(self, input_texts):
        # Tokenize the noisy text inputs and move them to the model's device
        input_ids = self.tokenizer(input_texts, return_tensors='pt', padding=True,
                                   truncation=True, max_length=self.max_length)['input_ids']
        input_ids = input_ids.to(self.mbart.device)

        # mBART returns output logits
        output_logits = self.mbart(input_ids).logits

        return output_logits

    def generate_corrected_text(self, input_texts, temperature=0.7):
        # Tokenize the noisy input text and move it to the model's device
        input_ids = self.tokenizer(input_texts, return_tensors='pt', padding=True,
                                   truncation=True, max_length=self.max_length)['input_ids']
        input_ids = input_ids.to(self.mbart.device)

        # Generate corrected text using mBART's generate function
        # (note: temperature only has an effect if do_sample=True is also passed)
        mbart_outputs = self.mbart.generate(input_ids, max_length=self.max_length,
                                            temperature=temperature, num_return_sequences=1)

        # Decode the generated token ids back into text
        corrected_texts = [self.tokenizer.decode(g, skip_special_tokens=True) for g in mbart_outputs]
        return corrected_texts


# Example usage
model = TextRefinementModel()

noisy_text = ["これは間違ったテキストの例です。", "这是错误的文本示例。"]  # Japanese and Chinese examples
corrected_text = model.generate_corrected_text(noisy_text)

print(f"Corrected Text: {corrected_text}")
```
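The commented-out `src_lang` lines in `__init__` hint at how to steer the tokenizer toward a specific source language. The short sketch below is a usage illustration added here, not part of the original card: it sets `src_lang` on the already-instantiated model before generation, using the mBART-50 language codes `ja_XX` and `zh_CN` mentioned above.

```python
# Usage sketch (illustrative): explicitly set the tokenizer's source language
# before asking the model to correct text.
model.tokenizer.src_lang = 'ja_XX'  # mBART-50 code for Japanese
japanese_corrected = model.generate_corrected_text(["これは間違ったテキストの例です。"])

model.tokenizer.src_lang = 'zh_CN'  # mBART-50 code for Chinese
chinese_corrected = model.generate_corrected_text(["这是错误的文本示例。"])

print(japanese_corrected, chinese_corrected)
```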
For training (fine-tuning on pairs of noisy and corrected text):
```python
import numpy as np
from torch.optim import AdamW  # transformers.AdamW is deprecated; torch's AdamW is a drop-in replacement
from torch.utils.data import DataLoader
from tqdm import tqdm

# Initialize the mBART model and optimizer
model = TextRefinementModel().cuda()
optimizer = AdamW(model.parameters(), lr=5e-5)

batch_size = 16


# Create a custom dataset class
class TextCorrectionDataset(torch.utils.data.Dataset):
    def __init__(self, data, tokenizer, max_length=64):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        noisy_text, correct_text = self.data[idx]
        inputs = self.tokenizer(noisy_text, return_tensors='pt', padding='max_length',
                                truncation=True, max_length=self.max_length)
        labels = self.tokenizer(correct_text, return_tensors='pt', padding='max_length',
                                truncation=True, max_length=self.max_length)

        # Drop the extra batch dimension added by return_tensors='pt'
        input_ids = inputs['input_ids'].squeeze()
        labels = labels['input_ids'].squeeze()  # Same for labels
        return input_ids, labels


# Create DataLoader with batching
train_dataset = TextCorrectionDataset(train_data, model.tokenizer)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)


# Define the training loop over batches
def train_epoch(model, train_loader, optimizer):
    model.train()
    total_loss = []
    step_iter = 0
    for input_ids, labels in tqdm(train_loader):
        # Move tensors to the model's device
        input_ids = input_ids.to(model.mbart.device)
        labels = labels.to(model.mbart.device)

        # Forward pass (mBART computes the loss itself when labels are provided)
        outputs = model.mbart(input_ids=input_ids, labels=labels)
        loss = outputs.loss

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss.append(loss.item())

        if step_iter % 100 == 0:
            print('Loss:', np.mean(total_loss))

        step_iter += 1
    return np.mean(total_loss)


# Example training loop
for epoch in range(5):  # Train for 5 epochs (or as needed)
    loss = train_epoch(model, train_loader, optimizer)
    print(f"Epoch {epoch+1}, Loss: {loss:.4f}")
```
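The training snippet assumes a `train_data` variable that is never defined in this card. From the way `TextCorrectionDataset.__getitem__` unpacks each item, it should be an iterable of `(noisy_text, correct_text)` string pairs. The sketch below shows one hypothetical way to construct it and to persist the fine-tuned weights with the standard `save_pretrained` API; the example pairs and the output directory are placeholders, not part of the original card.

```python
# Hypothetical example of the train_data structure expected above:
# a list of (noisy_text, correct_text) string pairs. The pairs shown here
# are placeholders; substitute your own parallel data.
train_data = [
    ("これは間違ったテキストの例です。", "これは正しいテキストの例です。"),
    ("这是错误的文本示例。", "这是正确的文本示例。"),
]

# After training, save the fine-tuned model and tokenizer so they can be
# reloaded later with from_pretrained(). The directory name is a placeholder.
model.mbart.save_pretrained('./custom-mbart-large-50-finetuned')
model.tokenizer.save_pretrained('./custom-mbart-large-50-finetuned')
```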