Move tensor to device in greedy translation function
app.py CHANGED
@@ -8,8 +8,9 @@ def en_translate_ar_beam(text, model, tokenizer, max_tries=50):
     return "future work"
 
 
+device = 'cuda' if torch.cuda.is_available() else 'cpu'
 def en_translate_ar_greedy(text, model, tokenizer, max_tries=50):
-    source_tensor = torch.tensor(tokenizer(text)).unsqueeze(0)
+    source_tensor = torch.tensor(tokenizer(text)).unsqueeze(0).to(device)
     target_tokens = greedy_decode(model, source_tensor,
                                   tokenizer.get_tokenId('<s>'),
                                   tokenizer.get_tokenId('</s>'),
@@ -18,7 +19,6 @@ def en_translate_ar_greedy(text, model, tokenizer, max_tries=50):
     return tokenizer.decode(target_tokens)
 
 
-device = 'cuda' if torch.cuda.is_available() else 'cpu'
 tokenizer = Callable_tokenizer('./assets/tokenizers/en-ar_tokenizer.model')
 
 model_state_dict = torch.load("./assets/models/en-ar_s2sAttention.pth", map_location=device, weights_only=True)['model_state_dict']
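For context, the fix works because greedy_decode runs the model's forward pass on source_tensor, and PyTorch requires the input tensor and the model's parameters to sit on the same device. A minimal, self-contained sketch of that pattern follows; the Linear model and token ids are placeholders standing in for the app's actual seq2seq model and tokenizer output, not part of this commit:

import torch

# Pick the same device the commit selects in app.py.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Placeholder model and token ids; app.py uses an en-ar seq2seq model and
# Callable_tokenizer(text) instead.
model = torch.nn.Linear(4, 4).to(device)
token_ids = [1, 5, 9, 2]

# Same pattern as the changed line: batch the ids, then move them to `device`.
source_tensor = torch.tensor(token_ids).unsqueeze(0).to(device)

# With parameters and input on the same device, the forward pass (or
# greedy_decode in app.py) runs without a device-mismatch RuntimeError.
output = model(source_tensor.float())
print(output.device)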