from dataclasses import dataclass

import loralib as lora


@dataclass
class LoraConfig:
    r: int  # LoRA rank
    lora_alpha: float  # LoRA scaling factor (loralib scales the update by lora_alpha / r)
    lora_dropout: float = 0.0  # dropout applied to the LoRA branch


def setup_lora(model, lora_config):
    """Wrap the model's embeddings and linear projections with LoRA layers,
    then freeze everything except the LoRA parameters."""

    # Wrap the token and codebook embeddings.
    model.embeddings = lora.Embedding(
        num_embeddings=model.embeddings.num_embeddings,
        embedding_dim=model.embeddings.embedding_dim,
        padding_idx=model.embeddings.padding_idx,
        r=lora_config.r,
        lora_alpha=lora_config.lora_alpha,
    )

    model.codebook_embeddings = lora.Embedding(
        num_embeddings=model.codebook_embeddings.num_embeddings,
        embedding_dim=model.codebook_embeddings.embedding_dim,
        padding_idx=model.codebook_embeddings.padding_idx,
        r=lora_config.r,
        lora_alpha=lora_config.lora_alpha,
    )

    # Collect every linear projection that should receive a LoRA adapter:
    # the output head plus each layer's attention and feed-forward projections.
    linears = [(model, "output")]

    for layer in model.layers:
        linears.extend([(layer.attention, "wqkv"), (layer.attention, "wo")])
        linears.extend(
            [
                (layer.feed_forward, "w1"),
                (layer.feed_forward, "w2"),
                (layer.feed_forward, "w3"),
            ]
        )

    # Models with a second ("fast") transformer get the same treatment.
    if hasattr(model, "fast_layers"):
        model.fast_embeddings = lora.Embedding(
            num_embeddings=model.fast_embeddings.num_embeddings,
            embedding_dim=model.fast_embeddings.embedding_dim,
            padding_idx=model.fast_embeddings.padding_idx,
            r=lora_config.r,
            lora_alpha=lora_config.lora_alpha,
        )

        linears.append((model, "fast_output"))

        for layer in model.fast_layers:
            linears.extend([(layer.attention, "wqkv"), (layer.attention, "wo")])
            linears.extend(
                [
                    (layer.feed_forward, "w1"),
                    (layer.feed_forward, "w2"),
                    (layer.feed_forward, "w3"),
                ]
            )

    # Swap each collected nn.Linear for a lora.Linear of the same shape.
    for module, layer in linears:
        linear = getattr(module, layer)
        updated_linear = lora.Linear(
            in_features=linear.in_features,
            out_features=linear.out_features,
            bias=linear.bias is not None,  # lora.Linear expects a bool here, not the bias tensor
            r=lora_config.r,
            lora_alpha=lora_config.lora_alpha,
            lora_dropout=lora_config.lora_dropout,
        )
        setattr(module, layer, updated_linear)

    # Freeze the base weights; only the LoRA A/B matrices stay trainable.
    lora.mark_only_lora_as_trainable(model, bias="none")


def get_merged_state_dict(model):
    """Return a state dict with the LoRA updates folded into the base weights."""

    # loralib merges the LoRA deltas into the wrapped weights when the module is
    # switched to eval mode (merge_weights=True is the default for its layers).
    model.eval()

    # Drop the now-redundant lora_* parameters so the result loads into the
    # original, non-LoRA model definition.
    state_dict = model.state_dict()
    for name in list(state_dict.keys()):
        if "lora" in name:
            state_dict.pop(name)

    return state_dict
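
# Usage sketch. The attribute names accessed above (model.layers, attention.wqkv,
# feed_forward.w1/w2/w3, output, and the optional fast_* variants) are assumed to
# exist on the target transformer; the hyperparameter values below are illustrative,
# and the training loop itself is elided.
#
#   lora_config = LoraConfig(r=8, lora_alpha=16, lora_dropout=0.01)
#   setup_lora(model, lora_config)          # swap in LoRA layers, freeze base weights
#   ...load pretrained weights and fine-tune only the LoRA parameters...
#   merged = get_merged_state_dict(model)   # LoRA deltas folded into base weights
#   torch.save(merged, "model.merged.pth")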
|
|