Spaces:
Build error
Build error
import re

import torch
from torch import nn
def freeze_whole_model(model: "nn.Module") -> None:
    """Disable gradient tracking for every parameter of *model*.

    Args:
        model: Object exposing ``parameters()`` (e.g. a ``torch.nn.Module``).
            Its parameters are modified in place.
    """
    # The parameter names are not needed here, so iterate the parameters
    # directly instead of the (name, parameter) pairs.
    for param in model.parameters():
        param.requires_grad = False


def unfreeze_parameters(model: "nn.Module", config) -> None:
    """Re-enable gradients for parameters whose name contains a target substring.

    ``'connector'`` is always a target. Two optional flags read from *config*
    extend the target list:

    * ``unfreeze_text_layer_norm`` adds ``'self_attn_layer_norm'`` and
      ``'final_layer_norm'``.
    * ``unfreeze_vision_layer_norm`` adds ``'norm'``, ``'norm1'`` and
      ``'norm2'``.

    The target list and the name of every parameter that gets unfrozen are
    printed to stdout.

    Args:
        model: Object exposing ``named_parameters()`` (e.g. a
            ``torch.nn.Module``). Matching parameters are modified in place.
        config: Mapping-like object supporting ``.get(key, default)``.
    """
    # The original code noted 'lm_head' next to this list as a further
    # candidate; it is not enabled.
    targets = ['connector']
    if config.get('unfreeze_text_layer_norm', False):
        targets = targets + ['self_attn_layer_norm', 'final_layer_norm']
    if config.get('unfreeze_vision_layer_norm', False):
        # NOTE(review): matching below is by substring, so 'norm' alone also
        # matches names such as 'self_attn_layer_norm' / 'final_layer_norm'
        # and makes 'norm1' / 'norm2' redundant -- confirm this breadth is
        # intended before narrowing it.
        targets = targets + ['norm', 'norm1', 'norm2']

    print('unfreeze targets:', targets)
    for name, param in model.named_parameters():
        if any(target in name for target in targets):
            param.requires_grad = True
            print(f"{name} is trainable...")


def print_trainable_params_percentage(model: "nn.Module") -> float:
    """Print and return the percentage of parameter elements that are trainable.

    An element counts as trainable when its parameter has
    ``requires_grad`` set.

    Args:
        model: Object exposing ``parameters()`` (e.g. a ``torch.nn.Module``).

    Returns:
        ``trainable / total * 100`` as a float, or ``0.0`` when the model has
        no parameter elements at all.
    """
    total_size = 0
    trainable_size = 0
    # One pass collects both counts.
    for param in model.parameters():
        count = param.numel()
        total_size += count
        if param.requires_grad:
            trainable_size += count

    # Guard against ZeroDivisionError for a model without parameters.
    if total_size:
        percentage = trainable_size / total_size * 100
    else:
        percentage = 0.0

    print(f"Trainable param percentage: {percentage:.2f}% ({trainable_size}/{total_size})")
    return percentage