waleko committed
Commit fbbcdd2 · 1 Parent(s): 8322ba1

Add `accelerate`

Files changed (1)
  1. translate.py +11 -5
translate.py CHANGED
@@ -6,9 +6,15 @@ import numpy as np
 from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
 import torch
 
+from accelerate import Accelerator
+accelerator = Accelerator()
+
 tokenizer = AutoTokenizer.from_pretrained("under-tree/transformer-en-ru")
 model = AutoModelForSeq2SeqLM.from_pretrained("under-tree/transformer-en-ru")
 
+device = accelerator.device
+
+model = accelerator.prepare(model)
 
 @dataclass
 class TranslationResult:
@@ -24,15 +30,15 @@ class TranslationResult:
 
 def translator_fn(input_text: str, k=10) -> TranslationResult:
     # Preprocess input
-    inputs = tokenizer(input_text, return_tensors="pt")
+    inputs = tokenizer(input_text, return_tensors="pt").to(device)
     input_tokens = tokenizer.batch_decode(inputs.input_ids[0])
-    input_special_mask = torch.tensor([1 if t in tokenizer.all_special_tokens else 0 for t in input_tokens])
+    input_special_mask = torch.tensor([1 if t in tokenizer.all_special_tokens else 0 for t in input_tokens]).to(device)
 
     # Generate output
     outputs = model.generate(**inputs, return_dict_in_generate=True, output_scores=True, output_attentions=True)
     output_text = tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)
     output_tokens = tokenizer.batch_decode(outputs.sequences[0])
-    output_special_mask = torch.tensor([1 if t in tokenizer.all_special_tokens else 0 for t in output_tokens])
+    output_special_mask = torch.tensor([1 if t in tokenizer.all_special_tokens else 0 for t in output_tokens]).to(device)
 
     # Get cross attention matrix
     cross_attention = torch.stack([torch.stack(t) for t in outputs.cross_attentions])
@@ -61,8 +67,8 @@ def translator_fn(input_text: str, k=10) -> TranslationResult:
     clean_output_tokens = [t for t, m in zip(output_tokens, output_special_mask) if m == 0]
     clean_input_tokens = [t for t, m in zip(input_tokens, input_special_mask) if m == 0]
     clean_attention_matrix = attention_matrix[:len_output, :len_input]  # for padding
-    clean_attention_matrix = np.delete(clean_attention_matrix, np.where(output_special_mask == 1), axis=0)
-    clean_attention_matrix = np.delete(clean_attention_matrix, np.where(input_special_mask == 1), axis=1)
+    clean_attention_matrix = np.delete(clean_attention_matrix, np.where(output_special_mask.detach().cpu().numpy() == 1), axis=0)
+    clean_attention_matrix = np.delete(clean_attention_matrix, np.where(input_special_mask.detach().cpu().numpy() == 1), axis=1)
 
     n_input = len(clean_input_tokens)
     n_output = len(clean_output_tokens)
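
For quick reference, below is a minimal, self-contained sketch of the device-placement pattern this commit introduces: Accelerator picks the device, prepare() moves the model onto it, and inputs built afterwards are moved explicitly with .to(device) before generate(). The standalone-script layout and the example sentence are illustrative assumptions, not part of translate.py.

# Minimal sketch of the pattern added in this commit; assumes the same
# checkpoint is reachable and that accelerate/transformers/torch are installed.
from accelerate import Accelerator
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

accelerator = Accelerator()          # selects CUDA/MPS/CPU automatically
device = accelerator.device

tokenizer = AutoTokenizer.from_pretrained("under-tree/transformer-en-ru")
model = AutoModelForSeq2SeqLM.from_pretrained("under-tree/transformer-en-ru")
model = accelerator.prepare(model)   # moves (and, if needed, wraps) the model on `device`

# Inputs created later are not handled by prepare(), so move them explicitly,
# exactly as the updated translator_fn does with `.to(device)`.
inputs = tokenizer("How are you?", return_tensors="pt").to(device)  # example text is illustrative
output_ids = model.generate(**inputs)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))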