Update modeling_rwkv5.py
Browse files- modeling_rwkv5.py +20 -0
modeling_rwkv5.py
CHANGED
@@ -735,6 +735,26 @@ class Rwkv5Model(Rwkv5PreTrainedModel):
|
|
735 |
hidden_states=all_hidden_states, # None
|
736 |
attentions=all_self_attentions, # None
|
737 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
738 |
|
739 |
def _rescale_layers(self):
|
740 |
# Layers should be rescaled for inference only.
|
|
|
735 |
hidden_states=all_hidden_states, # None
|
736 |
attentions=all_self_attentions, # None
|
737 |
)
|
def _bnb_4bit_dequantize_and_rescale(self, target_layer, block_id):
    r"""
    Dequantize a 4-bit quantized layer's weights, rescale them for inference, and quantize
    them again in place on the layer.

    Args:
        target_layer: a layer whose ``weight`` is a bitsandbytes 4-bit parameter
            (carries a ``quant_state``).
        block_id: index of the transformer block this layer belongs to; determines the
            rescale exponent via ``self.config.rescale_every``.
    """
    if not is_bitsandbytes_available():
        raise ImportError("Please install bitsandbytes to use this method.")
    import bitsandbytes as bnb

    # Recover full-precision weights from the stored 4-bit data + quantization state.
    fp_weights = bnb.functional.dequantize_4bit(target_layer.weight.data, target_layer.weight.quant_state)

    # Halve once per `rescale_every` blocks: divide by 2 ** (block_id // rescale_every), in place.
    exponent = int(block_id // self.config.rescale_every)
    fp_weights.div_(2**exponent)

    # Re-quantize the rescaled weights. Params4bit quantizes on the CPU -> device move, so we
    # round-trip through CPU first (this creates an overhead :/). requires_grad=False because
    # gradients cannot be computed on top of 4-bit parameters anyway, and it avoids bugs with bnb.
    requantized = bnb.nn.Params4bit(fp_weights.to("cpu"), requires_grad=False).to(fp_weights.device)
    target_layer.weight = requantized
758 |
|
759 |
def _rescale_layers(self):
|
760 |
# Layers should be rescaled for inference only.
|