# Backbone wrapper modules: freezing / partial unfreezing / LoRA fine-tuning.
import torch
import numpy as np
from abc import ABC, abstractmethod
from torch import nn
from hydra.utils import instantiate
import copy
from peft import LoraConfig, get_peft_model
from utils.model_utils import print_trainable_parameters
def freeze(model):
    """Disable gradient updates for every parameter and switch to eval mode."""
    for _, param in model.named_parameters():
        param.requires_grad = False
    model.eval()
def unfreeze(model):
    """Enable gradients for all parameters and switch the model to train mode.

    Two CLIP parameters are kept frozen: the vision tower's final
    post-layernorm weight and bias (they stay fixed during fine-tuning).
    Models without those parameter names are simply fully unfrozen.

    Note: the original docstring had accidentally swallowed two lines of
    dead code inside the triple quotes; the behavior below is unchanged.
    """
    # Parameters that must stay frozen even when the rest of the model trains.
    frozen_names = {
        "clip.vision_model.post_layernorm.weight",
        "clip.vision_model.post_layernorm.bias",
    }
    for name, param in model.named_parameters():
        param.requires_grad = name not in frozen_names
    model.train()
def unfreeze_last(model, layer="11"):
    """Unfreeze only the parameters of one transformer layer; train mode.

    A parameter is trainable iff its dotted name has more than five
    components and the fifth component (index 4) equals ``layer`` — with
    the default ``"11"`` this matches e.g.
    ``clip.vision_model.encoder.layers.11.<...>``, i.e. the last block of
    a 12-layer encoder. Everything else is frozen.

    Args:
        model: the module whose parameters are toggled in place.
        layer: name-component identifying the block to unfreeze
            (default ``"11"``, the original hard-coded value).
    """
    for name, param in model.named_parameters():
        parts = name.split(".")
        # Equivalent to the original nested ifs: True only for deep names
        # whose 5th component matches the requested layer.
        param.requires_grad = len(parts) > 5 and parts[4] == layer
    model.train()
class FrozenBackbone(nn.Module):
    """Backbone/mid/head pipeline whose backbone is frozen at construction.

    ``backbone``, ``mid`` and ``head`` are wrapper objects exposing an
    ``.instance`` module; ``head`` additionally exposes ``.target_key``.
    """

    def __init__(self, backbone, mid, head):
        super().__init__()
        self.backbone = backbone.instance
        self.mid = mid.instance
        self.head = head.instance
        self.target_key = head.target_key
        # Gradients never flow into the backbone.
        freeze(self.backbone)

    def forward(self, x):
        """Encode ``x`` without tracking backbone gradients, then apply mid and head."""
        with torch.no_grad():
            features = self.backbone(x)
        return self.head(self.mid(features))
class UnfrozenBackbone(nn.Module):
    """Backbone/mid/head pipeline with a fully trainable backbone.

    ``backbone``, ``mid`` and ``head`` are wrapper objects exposing an
    ``.instance`` module; ``head`` additionally exposes ``.target_key``.
    """

    def __init__(self, backbone, mid, head):
        super().__init__()
        self.backbone = backbone.instance
        self.mid = mid.instance
        self.head = head.instance
        self.target_key = head.target_key
        # Enable gradients on (almost) all backbone parameters.
        unfreeze(self.backbone)

    def forward(self, x):
        """Encode ``x`` with gradients enabled, then apply mid and head."""
        features = self.backbone(x)
        return self.head(self.mid(features))
class UnfrozenPartBackbone(nn.Module):
    """Backbone/mid/head pipeline where only the backbone's last block trains.

    ``backbone``, ``mid`` and ``head`` are wrapper objects exposing an
    ``.instance`` module; ``head`` additionally exposes ``.target_key``.
    """

    def __init__(self, backbone, mid, head):
        super().__init__()
        self.backbone = backbone.instance
        self.mid = mid.instance
        self.head = head.instance
        self.target_key = head.target_key
        # Only the final transformer block of the backbone stays trainable.
        unfreeze_last(self.backbone)

    def forward(self, x):
        """Encode ``x``, then apply mid and head."""
        features = self.backbone(x)
        return self.head(self.mid(features))
class NoFeatureBackbone(nn.Module):
    """Head-only model: the input is passed straight to the head, no backbone."""

    def __init__(self, head):
        super().__init__()
        self.head = head.instance
        self.target_key = head.target_key

    def forward(self, x):
        """Apply the head directly to ``x``."""
        return self.head(x)
class ContrastiveFrozenBackbone(FrozenBackbone):
    """Frozen-backbone pipeline that also emits contrastive features.

    Outside ``"eval"`` mode the input dict is expected to carry a second
    ("positive") view of each sample under keys prefixed ``"pos_"``; both
    views are encoded so a contrastive loss can compare them.
    """

    def __init__(self, backbone, mid, head, mode):
        super().__init__(backbone, mid, head)
        self.mode = mode  # "eval" skips the positive branch

    def forward(self, x):
        # Backbone is frozen, so neither view needs gradient tracking.
        with torch.no_grad():
            features = self.backbone(x)
            if self.mode != "eval":
                # Bug fix: k.strip("pos_") stripped any leading/trailing
                # 'p'/'o'/'s'/'_' characters (e.g. "pos_pixel" -> "ixel");
                # slice off the exact "pos_" prefix instead.
                x_pos = {
                    k[len("pos_"):]: v.clone()
                    if isinstance(v, torch.Tensor)
                    else copy.deepcopy(v)
                    for k, v in x.items()
                    if k.startswith("pos_")
                }
                pos_features = self.backbone(x_pos)
        out = self.head(self.mid(features))
        # features[:, 0, :]: assumes (batch, tokens, dim) with the CLS token
        # first -- TODO confirm against the backbone's output layout.
        if self.mode != "eval":
            return {
                "features": features[:, 0, :],
                "pos_features": pos_features[:, 0, :],
                **out,
            }
        return {
            "features": features[:, 0, :],
            **out,
        }
class ContrastiveUnFrozenPartBackbone(UnfrozenPartBackbone):
    """Partially-unfrozen-backbone pipeline that also emits contrastive features.

    Outside ``"eval"`` mode the input dict is expected to carry a second
    ("positive") view of each sample under keys prefixed ``"pos_"``; both
    views are encoded so a contrastive loss can compare them.
    """

    def __init__(self, backbone, mid, head, mode):
        super().__init__(backbone, mid, head)
        self.mode = mode  # "eval" skips the positive branch

    def forward(self, x):
        features = self.backbone(x)
        if self.mode != "eval":
            # Bug fix: k.strip("pos_") stripped any leading/trailing
            # 'p'/'o'/'s'/'_' characters (e.g. "pos_pixel" -> "ixel");
            # slice off the exact "pos_" prefix instead.
            x_pos = {
                k[len("pos_"):]: v.clone()
                if isinstance(v, torch.Tensor)
                else copy.deepcopy(v)
                for k, v in x.items()
                if k.startswith("pos_")
            }
            pos_features = self.backbone(x_pos)
        out = self.head(self.mid(features))
        # features[:, 0, :]: assumes (batch, tokens, dim) with the CLS token
        # first -- TODO confirm against the backbone's output layout.
        if self.mode != "eval":
            return {
                "features": features[:, 0, :],
                "pos_features": pos_features[:, 0, :],
                **out,
            }
        return {
            "features": features[:, 0, :],
            **out,
        }
class ContrastiveUnFrozenBackbone(UnfrozenBackbone):
    """Unfrozen-backbone pipeline that also emits contrastive features.

    Outside ``"eval"`` mode the input dict is expected to carry a second
    ("positive") view of each sample under keys prefixed ``"pos_"``; both
    views are encoded so a contrastive loss can compare them.
    """

    def __init__(self, backbone, mid, head, mode):
        super().__init__(backbone, mid, head)
        self.mode = mode  # "eval" skips the positive branch

    def forward(self, x):
        features = self.backbone(x)
        if self.mode != "eval":
            # Bug fix: k.strip("pos_") stripped any leading/trailing
            # 'p'/'o'/'s'/'_' characters (e.g. "pos_pixel" -> "ixel");
            # slice off the exact "pos_" prefix instead.
            x_pos = {
                k[len("pos_"):]: v.clone()
                if isinstance(v, torch.Tensor)
                else copy.deepcopy(v)
                for k, v in x.items()
                if k.startswith("pos_")
            }
            pos_features = self.backbone(x_pos)
        out = self.head(self.mid(features))
        # features[:, 0, :]: assumes (batch, tokens, dim) with the CLS token
        # first -- TODO confirm against the backbone's output layout.
        if self.mode != "eval":
            return {
                "features": features[:, 0, :],
                "pos_features": pos_features[:, 0, :],
                **out,
            }
        return {
            "features": features[:, 0, :],
            **out,
        }
class TextContrastiveUnFrozenBackbone(UnfrozenBackbone):
    """Unfrozen backbone whose encoder returns a (contrastive, features) pair.

    Docstring fixed: the previous one claimed the backbone was frozen.
    The redundant pass-through ``__init__`` was removed — the inherited
    ``UnfrozenBackbone.__init__(backbone, mid, head)`` is identical.
    """

    def forward(self, x):
        # Backbone yields the contrastive embedding and the token features.
        con, features = self.backbone(x)
        out = self.head(self.mid(features))
        return {
            "features": con,
            **out,
        }
class LoraBackbone(nn.Module):
    """Pipeline whose backbone is frozen and then wrapped with PEFT LoRA adapters.

    Only the injected LoRA matrices (on q/k/v projections) remain trainable.
    """

    def __init__(self, backbone, mid, head, r, alpha, dropout, bias):
        super().__init__()
        self.backbone = backbone.instance
        self.mid = mid.instance
        self.head = head.instance
        self.target_key = head.target_key
        # Base weights never train; LoRA adapters carry all the updates.
        freeze(self.backbone)
        lora_cfg = LoraConfig(
            r=r,
            lora_alpha=alpha,
            lora_dropout=dropout,
            bias=bias,
            target_modules=["q_proj", "k_proj", "v_proj"],
        )
        self.backbone = get_peft_model(self.backbone, lora_cfg)
        print_trainable_parameters(self)

    def forward(self, x):
        """Backbone -> mid -> head."""
        hidden = self.backbone(x)
        hidden = self.mid(hidden)
        return self.head(hidden)
class HybridFrozenBackbone(FrozenBackbone):
    """Frozen-backbone variant whose head also receives the ground-truth label."""

    def forward(self, x):
        """Encode without backbone gradients; head gets labels only in training."""
        # Labels are only available (and only used) during training.
        gt_label = x["label"] if self.training else None
        with torch.no_grad():
            features = self.backbone(x)
        return self.head(self.mid(features), gt_label)
class HybridUnfrozenBackbone(UnfrozenBackbone):
    """Unfrozen-backbone variant whose head also receives the ground-truth label."""

    def forward(self, x):
        """Encode with gradients; head gets labels only in training."""
        # Labels are only available (and only used) during training.
        gt_label = x["label"] if self.training else None
        features = self.backbone(x)
        return self.head(self.mid(features), gt_label)
class ContrastiveHybridUnFrozenBackbone(UnfrozenBackbone):
    """Unfrozen-backbone pipeline with contrastive features and a label-aware head.

    Outside ``"eval"`` mode the input dict is expected to carry a second
    ("positive") view of each sample under keys prefixed ``"pos_"``; both
    views are encoded so a contrastive loss can compare them. During
    training the head additionally receives the ground-truth label.
    """

    def __init__(self, backbone, mid, head, mode):
        super().__init__(backbone, mid, head)
        self.mode = mode  # "eval" skips the positive branch

    def forward(self, x):
        # Labels are only available (and only used) during training.
        gt_label = x["label"] if self.training else None
        features = self.backbone(x)
        if self.mode != "eval":
            # Bug fix: k.strip("pos_") stripped any leading/trailing
            # 'p'/'o'/'s'/'_' characters (e.g. "pos_pixel" -> "ixel");
            # slice off the exact "pos_" prefix instead.
            x_pos = {
                k[len("pos_"):]: v.clone()
                if isinstance(v, torch.Tensor)
                else copy.deepcopy(v)
                for k, v in x.items()
                if k.startswith("pos_")
            }
            pos_features = self.backbone(x_pos)
        out = self.head(self.mid(features), gt_label)
        # features[:, 0, :]: assumes (batch, tokens, dim) with the CLS token
        # first -- TODO confirm against the backbone's output layout.
        if self.mode != "eval":
            return {
                "features": features[:, 0, :],
                "pos_features": pos_features[:, 0, :],
                **out,
            }
        return {
            "features": features[:, 0, :],
            **out,
        }