Iker committed on
Commit
236bc54
·
1 Parent(s): 136cf77

Move batch to device

Browse files
Files changed (1) hide show
  1. translate_troch2trt.py +7 -3
translate_troch2trt.py CHANGED
@@ -52,24 +52,28 @@ def main(
52
  raise ValueError("Precision must be 16, 32 or 64.")
53
 
54
  if tensorrt:
 
55
  from torch2trt import torch2trt
56
 
57
  model = torch2trt(
58
- model, [torch.randn((batch_size, max_length)).to("cuda", dtype=dtype)]
59
  )
60
 
61
  else:
62
  if torch.cuda.is_available():
63
- model.to("cuda", dtype=dtype)
 
64
  else:
65
- model.to("cpu", dtype=dtype)
66
  print("CUDA not available. Using CPU. This will be slow.")
 
67
 
68
  with tqdm(total=total_lines, desc="Dataset translation") as pbar, open(
69
  output_path, "w+", encoding="utf-8"
70
  ) as output_file:
71
  with torch.no_grad():
72
  for batch in data_loader:
 
73
  generated_tokens = model.generate(
74
  **batch, forced_bos_token_id=lang_code_to_idx
75
  )
 
52
  raise ValueError("Precision must be 16, 32 or 64.")
53
 
54
  if tensorrt:
55
+ device = "cuda"
56
  from torch2trt import torch2trt
57
 
58
  model = torch2trt(
59
+ model, [torch.randn((batch_size, max_length)).to(device, dtype=dtype)]
60
  )
61
 
62
  else:
63
  if torch.cuda.is_available():
64
+ device = "cuda"
65
+
66
  else:
67
+ device = "cpu"
68
  print("CUDA not available. Using CPU. This will be slow.")
69
+ model.to(device, dtype=dtype)
70
 
71
  with tqdm(total=total_lines, desc="Dataset translation") as pbar, open(
72
  output_path, "w+", encoding="utf-8"
73
  ) as output_file:
74
  with torch.no_grad():
75
  for batch in data_loader:
76
+ batch = batch.to(device, dtype=dtype)
77
  generated_tokens = model.generate(
78
  **batch, forced_bos_token_id=lang_code_to_idx
79
  )