Add `accelerate`
translate.py CHANGED (+11 -5)
```diff
@@ -6,9 +6,15 @@ import numpy as np
 from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
 import torch
 
+from accelerate import Accelerator
+accelerator = Accelerator()
+
 tokenizer = AutoTokenizer.from_pretrained("under-tree/transformer-en-ru")
 model = AutoModelForSeq2SeqLM.from_pretrained("under-tree/transformer-en-ru")
 
+device = accelerator.device
+
+model = accelerator.prepare(model)
 
 @dataclass
 class TranslationResult:
@@ -24,15 +30,15 @@ class TranslationResult:
 
 def translator_fn(input_text: str, k=10) -> TranslationResult:
     # Preprocess input
-    inputs = tokenizer(input_text, return_tensors="pt")
+    inputs = tokenizer(input_text, return_tensors="pt").to(device)
     input_tokens = tokenizer.batch_decode(inputs.input_ids[0])
-    input_special_mask = torch.tensor([1 if t in tokenizer.all_special_tokens else 0 for t in input_tokens])
+    input_special_mask = torch.tensor([1 if t in tokenizer.all_special_tokens else 0 for t in input_tokens]).to(device)
 
     # Generate output
     outputs = model.generate(**inputs, return_dict_in_generate=True, output_scores=True, output_attentions=True)
     output_text = tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)
     output_tokens = tokenizer.batch_decode(outputs.sequences[0])
-    output_special_mask = torch.tensor([1 if t in tokenizer.all_special_tokens else 0 for t in output_tokens])
+    output_special_mask = torch.tensor([1 if t in tokenizer.all_special_tokens else 0 for t in output_tokens]).to(device)
 
     # Get cross attention matrix
     cross_attention = torch.stack([torch.stack(t) for t in outputs.cross_attentions])
@@ -61,8 +67,8 @@ def translator_fn(input_text: str, k=10) -> TranslationResult:
     clean_output_tokens = [t for t, m in zip(output_tokens, output_special_mask) if m == 0]
     clean_input_tokens = [t for t, m in zip(input_tokens, input_special_mask) if m == 0]
     clean_attention_matrix = attention_matrix[:len_output, :len_input]  # for padding
-    clean_attention_matrix = np.delete(clean_attention_matrix, np.where(output_special_mask == 1), axis=0)
-    clean_attention_matrix = np.delete(clean_attention_matrix, np.where(input_special_mask == 1), axis=1)
+    clean_attention_matrix = np.delete(clean_attention_matrix, np.where(output_special_mask.detach().cpu().numpy() == 1), axis=0)
+    clean_attention_matrix = np.delete(clean_attention_matrix, np.where(input_special_mask.detach().cpu().numpy() == 1), axis=1)
 
     n_input = len(clean_input_tokens)
     n_output = len(clean_output_tokens)
```