Update lit_llama/utils.py
Browse files- lit_llama/utils.py +3 -3
lit_llama/utils.py
CHANGED
|
@@ -89,13 +89,13 @@ class EmptyInitOnDevice(torch.overrides.TorchFunctionMode):
|
|
| 89 |
if self.quantization_mode == 'llm.int8':
|
| 90 |
if device.type != "cuda":
|
| 91 |
raise ValueError("Quantization is only supported on the GPU.")
|
| 92 |
- from quantization import Linear8bitLt
|
| 93 |
self.quantized_linear_cls = Linear8bitLt
|
| 94 |
elif self.quantization_mode == 'gptq.int4':
|
| 95 |
- from quantization import ColBlockQuantizedLinear
|
| 96 |
self.quantized_linear_cls = functools.partial(ColBlockQuantizedLinear, bits=4, tile_cols=-1)
|
| 97 |
elif self.quantization_mode == 'gptq.int8':
|
| 98 |
- from quantization import ColBlockQuantizedLinear
|
| 99 |
self.quantized_linear_cls = functools.partial(ColBlockQuantizedLinear, bits=8, tile_cols=-1)
|
| 100 |
elif self.quantization_mode is not None:
|
| 101 |
raise RuntimeError(f"unknown quantization mode {self.quantization_mode}")
|
|
|
|
| 89 |
if self.quantization_mode == 'llm.int8':
|
| 90 |
if device.type != "cuda":
|
| 91 |
raise ValueError("Quantization is only supported on the GPU.")
|
| 92 |
+ from .quantization import Linear8bitLt
|
| 93 |
self.quantized_linear_cls = Linear8bitLt
|
| 94 |
elif self.quantization_mode == 'gptq.int4':
|
| 95 |
+ from .quantization import ColBlockQuantizedLinear
|
| 96 |
self.quantized_linear_cls = functools.partial(ColBlockQuantizedLinear, bits=4, tile_cols=-1)
|
| 97 |
elif self.quantization_mode == 'gptq.int8':
|
| 98 |
+ from .quantization import ColBlockQuantizedLinear
|
| 99 |
self.quantized_linear_cls = functools.partial(ColBlockQuantizedLinear, bits=8, tile_cols=-1)
|
| 100 |
elif self.quantization_mode is not None:
|
| 101 |
raise RuntimeError(f"unknown quantization mode {self.quantization_mode}")
|