gyrojeff commited on
Commit
eb2d25d
·
1 Parent(s): 01d9a57

feat: add torch compile feature

Browse files
Files changed (1) hide show
  1. train.py +3 -0
train.py CHANGED
@@ -154,5 +154,8 @@ detector = FontDetector(
154
  num_epochs=num_epochs,
155
  )
156
 
 
 
 
157
  trainer.fit(detector, datamodule=data_module, ckpt_path=args.checkpoint)
158
  trainer.test(detector, datamodule=data_module)
 
154
  num_epochs=num_epochs,
155
  )
156
 
157
+ if torch.__version__ >= "2.0":
158
+ detector = torch.compile(detector)
159
+
160
  trainer.fit(detector, datamodule=data_module, ckpt_path=args.checkpoint)
161
  trainer.test(detector, datamodule=data_module)