feat: add torch compile feature
Browse files
train.py
CHANGED
@@ -154,5 +154,8 @@ detector = FontDetector(
|
|
154 |
num_epochs=num_epochs,
|
155 |
)
|
156 |
|
|
|
|
|
|
|
157 |
trainer.fit(detector, datamodule=data_module, ckpt_path=args.checkpoint)
|
158 |
trainer.test(detector, datamodule=data_module)
|
|
|
154 |
num_epochs=num_epochs,
|
155 |
)
|
156 |
|
157 |
+
if torch.__version__ >= "2.0":
|
158 |
+
detector = torch.compile(detector)
|
159 |
+
|
160 |
trainer.fit(detector, datamodule=data_module, ckpt_path=args.checkpoint)
|
161 |
trainer.test(detector, datamodule=data_module)
|