Add `accelerate`
translate.py  CHANGED  (+11 -5)
```diff
@@ -6,9 +6,15 @@ import numpy as np
 from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
 import torch
 
+from accelerate import Accelerator
+accelerator = Accelerator()
+
 tokenizer = AutoTokenizer.from_pretrained("under-tree/transformer-en-ru")
 model = AutoModelForSeq2SeqLM.from_pretrained("under-tree/transformer-en-ru")
 
+device = accelerator.device
+
+model = accelerator.prepare(model)
 
 @dataclass
 class TranslationResult:
@@ -24,15 +30,15 @@ class TranslationResult:
 
 def translator_fn(input_text: str, k=10) -> TranslationResult:
     # Preprocess input
-    inputs = tokenizer(input_text, return_tensors="pt")
+    inputs = tokenizer(input_text, return_tensors="pt").to(device)
     input_tokens = tokenizer.batch_decode(inputs.input_ids[0])
-    input_special_mask = torch.tensor([1 if t in tokenizer.all_special_tokens else 0 for t in input_tokens])
+    input_special_mask = torch.tensor([1 if t in tokenizer.all_special_tokens else 0 for t in input_tokens]).to(device)
 
     # Generate output
     outputs = model.generate(**inputs, return_dict_in_generate=True, output_scores=True, output_attentions=True)
     output_text = tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)
     output_tokens = tokenizer.batch_decode(outputs.sequences[0])
-    output_special_mask = torch.tensor([1 if t in tokenizer.all_special_tokens else 0 for t in output_tokens])
+    output_special_mask = torch.tensor([1 if t in tokenizer.all_special_tokens else 0 for t in output_tokens]).to(device)
 
     # Get cross attention matrix
     cross_attention = torch.stack([torch.stack(t) for t in outputs.cross_attentions])
@@ -61,8 +67,8 @@ def translator_fn(input_text: str, k=10) -> TranslationResult:
     clean_output_tokens = [t for t, m in zip(output_tokens, output_special_mask) if m == 0]
     clean_input_tokens = [t for t, m in zip(input_tokens, input_special_mask) if m == 0]
     clean_attention_matrix = attention_matrix[:len_output, :len_input]  # for padding
-    clean_attention_matrix = np.delete(clean_attention_matrix, np.where(output_special_mask == 1), axis=0)
-    clean_attention_matrix = np.delete(clean_attention_matrix, np.where(input_special_mask == 1), axis=1)
+    clean_attention_matrix = np.delete(clean_attention_matrix, np.where(output_special_mask.detach().cpu().numpy() == 1), axis=0)
+    clean_attention_matrix = np.delete(clean_attention_matrix, np.where(input_special_mask.detach().cpu().numpy() == 1), axis=1)
 
     n_input = len(clean_input_tokens)
     n_output = len(clean_output_tokens)
```