TheDemond commited on
Commit
6b6e8d3
·
1 Parent(s): 2c193e4

Move tensor to device in greedy translation function

Browse files
Files changed (1) hide show
  1. app.py +2 -2
app.py CHANGED
@@ -8,8 +8,9 @@ def en_translate_ar_beam(text, model, tokenizer, max_tries=50):
8
  return "future work"
9
 
10
 
 
11
  def en_translate_ar_greedy(text, model, tokenizer, max_tries=50):
12
- source_tensor = torch.tensor(tokenizer(text)).unsqueeze(0)
13
  target_tokens = greedy_decode(model, source_tensor,
14
  tokenizer.get_tokenId('<s>'),
15
  tokenizer.get_tokenId('</s>'),
@@ -18,7 +19,6 @@ def en_translate_ar_greedy(text, model, tokenizer, max_tries=50):
18
  return tokenizer.decode(target_tokens)
19
 
20
 
21
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
22
  tokenizer = Callable_tokenizer('./assets/tokenizers/en-ar_tokenizer.model')
23
 
24
  model_state_dict = torch.load("./assets/models/en-ar_s2sAttention.pth", map_location=device, weights_only=True)['model_state_dict']
 
8
  return "future work"
9
 
10
 
11
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
12
  def en_translate_ar_greedy(text, model, tokenizer, max_tries=50):
13
+ source_tensor = torch.tensor(tokenizer(text)).unsqueeze(0).to(device)
14
  target_tokens = greedy_decode(model, source_tensor,
15
  tokenizer.get_tokenId('<s>'),
16
  tokenizer.get_tokenId('</s>'),
 
19
  return tokenizer.decode(target_tokens)
20
 
21
 
 
22
  tokenizer = Callable_tokenizer('./assets/tokenizers/en-ar_tokenizer.model')
23
 
24
  model_state_dict = torch.load("./assets/models/en-ar_s2sAttention.pth", map_location=device, weights_only=True)['model_state_dict']