Iker committed on
Commit 136cf77 · 1 Parent(s): 1e19e28

Test torch2trt

Files changed (1)
  1. translate_troch2trt.py +156 -0
translate_troch2trt.py ADDED
@@ -0,0 +1,156 @@
+ from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer
+ from tqdm import tqdm
+ import argparse
+ import torch
+ from dataset import get_dataloader, count_lines
+ import os
+
+
+ def main(
+     sentences_path,
+     output_path,
+     source_lang,
+     target_lang,
+     batch_size,
+     model_name: str = "facebook/m2m100_1.2B",
+     tensorrt: bool = False,
+     precision: int = 32,
+     max_length: int = 128,
+ ):
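+     """Translate the sentences in sentences_path (one per line) from
+     source_lang to target_lang with M2M100, optionally compiling the
+     model with torch2trt first, and write the results to output_path."""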
+
+     output_dir = os.path.dirname(output_path)
+     # Guard against an output_path with no directory component.
+     if output_dir and not os.path.exists(output_dir):
+         os.makedirs(output_dir)
+
+     print("Loading tokenizer...")
+     tokenizer = M2M100Tokenizer.from_pretrained(model_name)
+     print("Loading model...")
+     model = M2M100ForConditionalGeneration.from_pretrained(model_name)
+     print("Model loaded.\n")
+
+     tokenizer.src_lang = source_lang
+     lang_code_to_idx = tokenizer.lang_code_to_id[target_lang]
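+     # lang_code_to_idx is the id of the target language token: M2M100
+     # selects the output language by forcing this token as the first
+     # generated token (forced_bos_token_id in generate() below).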
+
+     model.eval()
+
+     total_lines: int = count_lines(sentences_path)
+     print(f"We will translate {total_lines} lines.")
+     data_loader = get_dataloader(
+         filename=sentences_path,
+         tokenizer=tokenizer,
+         batch_size=batch_size,
+         max_length=max_length,
+     )
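+     # get_dataloader comes from the local dataset module; it is assumed
+     # to yield dicts of input_ids / attention_mask tensors truncated or
+     # padded to max_length, ready to be unpacked into generate().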
+
+     if precision == 16:
+         dtype = torch.float16
+     elif precision == 32:
+         dtype = torch.float32
+     elif precision == 64:
+         dtype = torch.float64
+     else:
+         raise ValueError("Precision must be 16, 32 or 64.")
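+     # Note: float16 is only practical on CUDA; many CPU ops in PyTorch
+     # do not support half precision.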
+
+     if tensorrt:
+         from torch2trt import torch2trt
+
+         # torch2trt expects the model and its example inputs on the GPU,
+         # and M2M100 consumes integer token ids, so trace with a dummy
+         # batch of token ids instead of random floats.
+         model.to("cuda", dtype=dtype)
+         dummy_input = torch.zeros(
+             (batch_size, max_length), dtype=torch.long, device="cuda"
+         )
+         model = torch2trt(model, [dummy_input])
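+         # Hypothetical follow-up (not exercised here): before reassigning
+         # `model`, one could keep a reference to the PyTorch model and
+         # compare its output with the TRTModule's on dummy_input to
+         # validate the conversion.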
+
+     else:
+         if torch.cuda.is_available():
+             model.to("cuda", dtype=dtype)
+         else:
+             model.to("cpu", dtype=dtype)
+             print("CUDA not available. Using CPU. This will be slow.")
+
+     with tqdm(total=total_lines, desc="Dataset translation") as pbar, open(
+         output_path, "w", encoding="utf-8"
+     ) as output_file:
+         with torch.no_grad():
+             for batch in data_loader:
+                 # Move the batch to the model's device; assumes the
+                 # dataloader yields a dict of tensors.
+                 batch = {k: v.to(model.device) for k, v in batch.items()}
+                 generated_tokens = model.generate(
+                     **batch, forced_bos_token_id=lang_code_to_idx
+                 )
+                 tgt_text = tokenizer.batch_decode(
+                     generated_tokens.cpu(), skip_special_tokens=True
+                 )
+
+                 print("\n".join(tgt_text), file=output_file)
+
+                 pbar.update(len(tgt_text))
+
+     print("Translation done.\n")
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser(description="Run the translation experiments")
+     parser.add_argument(
+         "--sentences_path",
+         type=str,
+         required=True,
+         help="Path to a txt file containing the sentences to translate. One sentence per line.",
+     )
+
+     parser.add_argument(
+         "--output_path",
+         type=str,
+         required=True,
+         help="Path to a txt file where the translated sentences will be written.",
+     )
+
+     parser.add_argument(
+         "--source_lang",
+         type=str,
+         required=True,
+         help="Source language id. See: https://huggingface.co/facebook/m2m100_1.2B",
+     )
+
+     parser.add_argument(
+         "--target_lang",
+         type=str,
+         required=True,
+         help="Target language id. See: https://huggingface.co/facebook/m2m100_1.2B",
+     )
+
+     parser.add_argument(
+         "--batch_size",
+         type=int,
+         default=8,
+         help="Batch size",
+     )
+
+     parser.add_argument(
+         "--model_name",
+         type=str,
+         default="facebook/m2m100_1.2B",
+         help="Path to the model to use. See: https://huggingface.co/models",
+     )
+
+     parser.add_argument(
+         "--precision",
+         type=int,
+         default=32,
+         choices=[16, 32, 64],
+         help="Precision of the model: 16, 32 or 64 bits.",
+     )
+
+     parser.add_argument(
+         "--max_length",
+         type=int,
+         default=128,
+         help="Maximum sequence length for tokenization and generation.",
+     )
+
+     parser.add_argument(
+         "--tensorrt",
+         action="store_true",
+         help="Use TensorRT to compile the model.",
+     )
+
+     args = parser.parse_args()
+
+     main(
+         sentences_path=args.sentences_path,
+         output_path=args.output_path,
+         source_lang=args.source_lang,
+         target_lang=args.target_lang,
+         batch_size=args.batch_size,
+         model_name=args.model_name,
+         precision=args.precision,
+         max_length=args.max_length,
+         tensorrt=args.tensorrt,
+     )
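
A minimal invocation, with hypothetical file paths, would look like:

python translate_troch2trt.py --sentences_path data/sentences.txt --output_path out/translations.txt --source_lang en --target_lang es --batch_size 8

Add --tensorrt to exercise the torch2trt path, which requires a CUDA GPU and the torch2trt package.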