uto1125 commited on
Commit
cb462fd
·
verified ·
1 Parent(s): 271093a

Update tools/llama/generate.py

Browse files
Files changed (1) hide show
  1. tools/llama/generate.py +1 -1
tools/llama/generate.py CHANGED
@@ -341,7 +341,7 @@ def encode_tokens(
341
 
342
  def load_model(checkpoint_path, device, precision, compile=False):
343
  model: Union[NaiveTransformer, DualARTransformer] = BaseTransformer.from_pretrained(
344
- checkpoint_path, load_weights=True
345
  )
346
 
347
  model = model.to(device=device, dtype=precision)
 
341
 
342
  def load_model(checkpoint_path, device, precision, compile=False):
343
  model: Union[NaiveTransformer, DualARTransformer] = BaseTransformer.from_pretrained(
344
+ checkpoint_path, load_weights=True, weights_only=True
345
  )
346
 
347
  model = model.to(device=device, dtype=precision)