uto1125 commited on
Commit
3aa656e
·
verified ·
1 Parent(s): cb462fd

Update tools/llama/generate.py

Browse files
Files changed (1) hide show
  1. tools/llama/generate.py +2 -1
tools/llama/generate.py CHANGED
@@ -344,7 +344,8 @@ def load_model(checkpoint_path, device, precision, compile=False):
344
  checkpoint_path, load_weights=True, weights_only=True
345
  )
346
 
347
- model = model.to(device=device, dtype=precision)
 
348
  logger.info(f"Restored model from checkpoint")
349
 
350
  if isinstance(model, DualARTransformer):
 
344
  checkpoint_path, load_weights=True, weights_only=True
345
  )
346
 
347
+ weights = torch.load(checkpoint_path, weights_only=True)
348
+ model = BaseTransformer.from_pretrained(weights)
349
  logger.info(f"Restored model from checkpoint")
350
 
351
  if isinstance(model, DualARTransformer):