Spaces:
Runtime error
Runtime error
Update tools/llama/generate.py
Browse files- tools/llama/generate.py +2 -1
tools/llama/generate.py
CHANGED
@@ -344,7 +344,8 @@ def load_model(checkpoint_path, device, precision, compile=False):
|
|
344 |
checkpoint_path, load_weights=True, weights_only=True
|
345 |
)
|
346 |
|
347 |
-
|
|
|
348 |
logger.info(f"Restored model from checkpoint")
|
349 |
|
350 |
if isinstance(model, DualARTransformer):
|
|
|
344 |
checkpoint_path, load_weights=True, weights_only=True
|
345 |
)
|
346 |
|
347 |
+
weights = torch.load(checkpoint_path, weights_only=True)
|
348 |
+
model = BaseTransformer.from_pretrained(weights)
|
349 |
logger.info(f"Restored model from checkpoint")
|
350 |
|
351 |
if isinstance(model, DualARTransformer):
|