Spaces:
Running
on
Zero
Running
on
Zero
| import torch | |
| from torch import nn | |
| from lora.lora_layers import LoraInjectedLinear, LoraInjectedConv2d | |
def _find_modules(
    model,
    ancestor_class=None,
    search_class=(nn.Linear,),
    exclude_children_of=(LoraInjectedLinear,),
):
    """Yield ``(parent, name, module)`` for every matching submodule of ``model``.

    Args:
        model: Root ``nn.Module`` to search.
        ancestor_class: Optional collection of class *names* (strings). When
            given, only modules whose ``__class__.__name__`` is in it are used
            as search roots. When ``None``, every module of ``model`` is a
            search root, so a module can be yielded once per root containing
            it (including itself as a root).
        search_class: Iterable of module types to look for. Defaults to
            ``(nn.Linear,)``.
        exclude_children_of: Iterable of module types. A match whose direct
            parent is an instance of one of these is skipped. Pass ``None``
            or an empty iterable to disable the exclusion.

    Yields:
        ``(parent, name, module)``, where ``name`` is the last component of
        the module's dotted path below the search root and ``parent`` is the
        module owning that attribute. If a search root itself matches, it is
        yielded as ``(root, "", root)``.
    """
    # Normalise once; isinstance() accepts a tuple of types directly.
    search_types = tuple(search_class)
    exclude_types = tuple(exclude_children_of) if exclude_children_of else ()

    # Get the targets we should replace all matching layers under.
    if ancestor_class is not None:
        # Kept lazy (generator) as before, so roots are discovered on demand.
        ancestors = (
            module
            for module in model.modules()
            if module.__class__.__name__ in ancestor_class
        )
    else:
        # Naively treat every module as a search root.
        ancestors = list(model.modules())

    for ancestor in ancestors:
        for fullname, module in ancestor.named_modules():
            if not isinstance(module, search_types):
                continue
            # `module` may be a deeper descendant of `ancestor`; resolve the
            # module that directly owns it.
            *path, name = fullname.split(".")
            parent = ancestor.get_submodule(".".join(path)) if path else ancestor
            # Skip layers that are themselves children of an excluded type
            # (e.g. the inner Linear of an already-injected LoRA layer).
            if exclude_types and isinstance(parent, exclude_types):
                continue
            yield parent, name, module
def extract_lora_ups_down(model, target_replace_module=frozenset({"AdaLayerNormZero"})):  # Attention for kv_lora
    """Collect the ``(lora_up, lora_down)`` pairs of every injected LoRA layer.

    Searches below modules whose class name is in ``target_replace_module``
    for ``LoraInjectedLinear`` / ``LoraInjectedConv2d`` instances.

    Args:
        model: Root ``nn.Module`` that has had LoRA layers injected.
        target_replace_module: Collection of class names used as search
            roots. Per the author's note, ``{"Attention"}`` is the choice for
            kv_lora.

    Returns:
        List of ``(lora_up, lora_down)`` tuples in discovery order.

    Raises:
        ValueError: If no LoRA-injected layer is found.
    """
    loras = [
        (lora_layer.lora_up, lora_layer.lora_down)
        for _parent, _name, lora_layer in _find_modules(
            model,
            target_replace_module,
            search_class=[LoraInjectedLinear, LoraInjectedConv2d],
        )
    ]
    if not loras:
        raise ValueError("No lora injected.")
    return loras
def save_lora_weight(
    model,
    path="./lora.pt",
    target_replace_module=frozenset({"AdaLayerNormZero"}),  # Attention for kv_lora
    save_half: bool = False,
):
    """Save the injected LoRA weights of ``model`` to ``path`` via ``torch.save``.

    The saved object is a flat list ``[up_0, down_0, up_1, down_1, ...]`` of
    CPU tensors, in the order returned by ``extract_lora_ups_down``.

    Args:
        model: Root ``nn.Module`` that has had LoRA layers injected.
        path: Destination passed to ``torch.save``.
        target_replace_module: Collection of class names used as search
            roots. Per the author's note, ``{"Attention"}`` is the choice for
            kv_lora.
        save_half: If True, weights are stored as ``float16``; otherwise as
            ``float32``.

    Raises:
        ValueError: If no LoRA-injected layer is found in ``model``.
    """
    # The target dtype is the same for every layer, so choose it once.
    dtype = torch.float16 if save_half else torch.float32
    weights = []
    for up, down in extract_lora_ups_down(
        model, target_replace_module=target_replace_module
    ):
        # Single .to() call moves to CPU and casts in one step.
        weights.append(up.weight.to(device="cpu", dtype=dtype))
        weights.append(down.weight.to(device="cpu", dtype=dtype))
    torch.save(weights, path)