uto1125 commited on
Commit
eeb6993
·
verified ·
1 Parent(s): 3aa656e

Update tools/llama/generate.py

Browse files
Files changed (1) hide show
  1. tools/llama/generate.py +1 -4
tools/llama/generate.py CHANGED
@@ -340,10 +340,7 @@ def encode_tokens(
340
 
341
 
342
  def load_model(checkpoint_path, device, precision, compile=False):
343
- model: Union[NaiveTransformer, DualARTransformer] = BaseTransformer.from_pretrained(
344
- checkpoint_path, load_weights=True, weights_only=True
345
- )
346
-
347
  weights = torch.load(checkpoint_path, weights_only=True)
348
  model = BaseTransformer.from_pretrained(weights)
349
  logger.info(f"Restored model from checkpoint")
 
340
 
341
 
342
  def load_model(checkpoint_path, device, precision, compile=False):
343
+
 
 
 
344
  weights = torch.load(checkpoint_path, weights_only=True)
345
  model = BaseTransformer.from_pretrained(weights)
346
  logger.info(f"Restored model from checkpoint")