Upload PAL_B_RM_opt

- config.json +4 -0
- configuration_pal_b_rm.py +35 -0
- connector.py +112 -0
- custom_sfx.py +61 -0
- itemLearner.py +53 -0
- learner.py +138 -0
- modeling_pal_b_rm.py +28 -0
- projector.py +81 -0
- pytorch_model.bin +1 -1
- tensor_initializer.py +32 -0
- tensor_merger.py +16 -0
- userLearner.py +151 -0
config.json
CHANGED

@@ -2,6 +2,10 @@
   "architectures": [
     "PAL_B_RM_opt"
   ],
+  "auto_map": {
+    "AutoConfig": "configuration_pal_b_rm.PAL_B_Config",
+    "AutoModel": "modeling_pal_b_rm.PAL_B_RM_opt"
+  },
   "d_hid": 512,
   "d_pref": 512,
   "initializer_type": "gaussian",
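The added auto_map entries are what let the custom config and model classes be resolved directly from this repository's Python files when loading with trust_remote_code=True. A minimal loading sketch; the repository id below is a placeholder, not taken from this commit:

from transformers import AutoConfig, AutoModel

repo_id = "your-namespace/PAL_B_RM_opt"  # placeholder: replace with the actual Hub repo path

# auto_map routes these calls to configuration_pal_b_rm.PAL_B_Config and
# modeling_pal_b_rm.PAL_B_RM_opt inside the repository.
config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)
print(type(config).__name__, type(model).__name__)  # PAL_B_Config PAL_B_RM_opt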
configuration_pal_b_rm.py
ADDED (+35 lines)

from transformers import PretrainedConfig

class PAL_B_Config(PretrainedConfig):
    model_type = "facebook/opt"  # opt family model aligned PAL reward model

    def __init__(
        self,
        d_hid: int = 512,
        d_pref: int = 512,
        k: int = 2,
        llm_name: str = "facebook/opt-350m",
        pref_learner_type: str = "angle",
        proj_arch: str = "mlp2-gelu-dropout0",
        initializer_type: str = "gaussian",
        is_expectation_norm_init: bool = False,
        sfx_type: str = "softmax",
        sfx_temperature: float = 1.0,
        is_temperature_learnable: bool = False,
        is_gumbel_hard: bool = None,
        **kwargs,
    ):
        self.d_hid = d_hid
        self.d_pref = d_pref
        self.k = k
        self.llm_name = llm_name
        self.pref_learner_type = pref_learner_type
        self.proj_arch = proj_arch
        self.initializer_type = initializer_type
        self.is_expectation_norm_init = is_expectation_norm_init
        self.sfx_type = sfx_type
        self.sfx_temperature = sfx_temperature
        self.is_temperature_learnable = is_temperature_learnable
        self.is_gumbel_hard = is_gumbel_hard
        super().__init__(**kwargs)
connector.py
ADDED (+112 lines)

from .projector import Projector
import torch.nn as nn
import re


# class Connector(nn.Module):
#     def __init__(self, cnct_arch:str, in_dims:int, out_dims:int):
#         super().__init__()
#         # projector_type structure ["mlp?-relu-dropout?-residual","identity"]
#         self.cnct_arch = cnct_arch
#
#         if cnct_arch == 'identity':
#             self.m = nn.Identity()
#
#         pattern = r"mlp(\d+)-(relu|gelu|linear)-dropout(\d+)?(-residual-batchnorm|-batchnorm-residual|-residual|-batchnorm|-nobias)?"
#         match = re.match(pattern, cnct_arch)
#
#         if match:
#             layers = int(match.group(1))
#             act = match.group(2)
#             dropout_p = int(match.group(3))
#             num_digit = len(match.group(3))
#             dropout_p = dropout_p / 10**num_digit
#             # print("match.group(4): ", match.group(4))
#             nobias = False
#             if match.group(4) != None:
#                 residual = True if ("-residual" in match.group(4)) else False
#                 batchnorm = True if ("-batchnorm" in match.group(4)) else False
#                 nobias = True if ("-nobias" in match.group(4)) else False
#             else:
#                 residual = False
#                 batchnorm = False
#             latent_dims = [out_dims] * layers
#             self.m = Projector(
#                 in_dims=in_dims,
#                 out_dims=out_dims,
#                 latent_dims=latent_dims,
#                 bias=not nobias,
#                 dropout_p=dropout_p,
#                 activation=act,
#                 identity_map=residual,
#                 use_batchnorm=batchnorm,
#             )
#
#     def forward(self, x):
#         return self.m(x)


class Connector(nn.Module):
    def __init__(self, in_dims: int, out_dims: int, cnct_arch: str):
        super().__init__()
        pattern = r"mlp(\d+)-(relu|gelu|linear)-dropout(\d+)?(-residual-batchnorm|-batchnorm-residual|-residual|-batchnorm|-nobias)?"
        match = re.match(pattern, cnct_arch)
        if match:
            layers = int(match.group(1))
            act = match.group(2)
            dropout_p = int(match.group(3))
            num_digit = len(match.group(3))
            dropout_p = dropout_p / 10**num_digit
            if match.group(4) != None:
                residual = True if ("-residual" in match.group(4)) else False
                batchnorm = True if ("-batchnorm" in match.group(4)) else False
                nobias = True if ("-nobias" in match.group(4)) else False
            else:
                residual = False
                batchnorm = False
                nobias = False
            latent_dims = [out_dims] * layers
            self.mlp = Projector(
                in_dims=in_dims,
                out_dims=out_dims,
                latent_dims=latent_dims,
                bias=not nobias,
                dropout_p=dropout_p,
                activation=act,
                identity_map=residual,
                use_batchnorm=batchnorm,
            )
        elif cnct_arch == 'identity':
            self.mlp = nn.Identity()
        else:
            raise ValueError(f'no such connection architecture {cnct_arch}')

    def __call__(self, x):
        ret = self.mlp(x)
        return ret

if __name__ == "__main__":
    m = Connector(cnct_arch='identity', in_dims=4096, out_dims=768)
    print(m)
    m = Connector(cnct_arch='mlp1-relu-dropout2-residual', in_dims=4096, out_dims=768)
    print(m)
    m = Connector(cnct_arch='mlp1-relu-dropout2-batchnorm', in_dims=4096, out_dims=768)
    print(m)
    m = Connector(cnct_arch='mlp1-relu-dropout2-residual-batchnorm', in_dims=4096, out_dims=768)
    print(m)
    m = Connector(cnct_arch='mlp3-gelu-dropout2', in_dims=4096, out_dims=768)
    print(m)
    m = Connector(cnct_arch='mlp16-relu-dropout75', in_dims=4096, out_dims=768)
    print(m)
    m = Connector(cnct_arch='mlp0-linear-dropout0', in_dims=4096, out_dims=768)
    print(m)
    m = Connector(cnct_arch='mlp0-linear-dropout0-nobias', in_dims=4096, out_dims=768)
    print(m)
    m = Connector(cnct_arch='mlp2-linear-dropout0-nobias', in_dims=4096, out_dims=768)
    print(m)

    m = Connector(cnct_arch='mlp2-gelu-dropout0', in_dims=512, out_dims=512)
    count = 0
    for p in m.parameters():
        count += p.numel()
    print(count)
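For reference, the cnct_arch string is decomposed by the regex above: mlp<N> sets the number of hidden layers, the second group picks the activation, and dropout<D> is read digit-wise, so dropout2 means p=0.2 and dropout75 means p=0.75. A small standalone sketch of that parsing, using the same pattern independently of the class:

import re

pattern = r"mlp(\d+)-(relu|gelu|linear)-dropout(\d+)?(-residual-batchnorm|-batchnorm-residual|-residual|-batchnorm|-nobias)?"

for arch in ["mlp2-gelu-dropout0", "mlp1-relu-dropout2-residual", "mlp16-relu-dropout75"]:
    m = re.match(pattern, arch)
    layers = int(m.group(1))
    act = m.group(2)
    # "0" -> 0.0, "2" -> 0.2, "75" -> 0.75, same arithmetic as in Connector.__init__
    dropout_p = int(m.group(3)) / 10 ** len(m.group(3))
    print(arch, "->", layers, "hidden layer(s),", act, ", dropout", dropout_p, ", flags:", m.group(4))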
custom_sfx.py
ADDED (+61 lines)

import torch
import torch.nn as nn
import torch.nn.functional as F

from typing import Literal, Optional

import logging
logging.basicConfig(level=logging.WARNING, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

class CustomSoftMax(nn.Module):
    def __init__(
        self,
        sfx_type: Literal['gumbel_softmax', 'softmax'],
        temperature: float,
        is_temperature_learnable: bool,
        is_gumbel_hard: Optional[bool] = None,  # [True/False]
        *args,
        **kwargs,
    ) -> None:

        super().__init__()
        self.sfx_type = sfx_type
        assert not is_temperature_learnable, 'is_temperature_learnable is prohibited in this version, will go to negative'
        self.temperature = nn.Parameter(torch.tensor([float(temperature)]), requires_grad=is_temperature_learnable)
        self.is_gumbel_hard = is_gumbel_hard
        self.args = args
        self.kwargs = kwargs

    def forward(self, x):
        # x: (bs, dims)
        if self.sfx_type == 'gumbel_softmax':
            if self.is_gumbel_hard is not None:
                return F.gumbel_softmax(x, tau=self.temperature, hard=self.is_gumbel_hard, dim=1)
            else:
                raise ValueError('is_gumbel_hard is not passed')
        elif self.sfx_type == 'softmax':
            return F.softmax(x / self.temperature, dim=1)
        else:
            raise NotImplementedError(f'{self.sfx_type} is not implemented yet')

if __name__ == "__main__":

    sfx = CustomSoftMax(sfx_type='gumbel_softmax', temperature=1, is_temperature_learnable=False, is_gumbel_hard=True)
    x = torch.randn(10, 3)  # (bs, dims)
    print(x.shape)
    print(sfx(x))

    # NOTE: the two examples below pass is_temperature_learnable=True, which trips the
    # assertion in __init__ and raises AssertionError with the module as written above.
    sfx = CustomSoftMax(sfx_type='gumbel_softmax', temperature=1, is_temperature_learnable=True, is_gumbel_hard=True)
    x = torch.randn(10, 3)  # (bs, dims)
    print(x.shape)
    print(sfx(x))

    sfx = CustomSoftMax(sfx_type='softmax', temperature=1, is_temperature_learnable=False)
    x = torch.randn(10, 3)
    print(sfx(x))

    sfx = CustomSoftMax(sfx_type='softmax', temperature=0.01, is_temperature_learnable=True, is_gumbel_hard=None)
    x = torch.randn(10, 3)
    print(sfx(x))
itemLearner.py
ADDED (+53 lines)

import torch
import torch.nn as nn
import torch.nn.functional as F
from .connector import Connector
from .projector import Projector
from .tensor_merger import TensorMerger
import numpy as np

from typing import Literal, Optional, Tuple
import logging
logging.basicConfig(level=logging.WARNING, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

class ItemLearner(nn.Module):

    llm: nn.Module
    projector: nn.Module

    def __init__(self, llm, projector):
        super().__init__()
        self.llm = llm
        self.projector = projector

    def forward(self, x, rm_cached=None):  # only pass the generated data
        '''
        x = {'input_ids': torch.tensor, 'attention_mask': torch.tensor}
        '''
        input_ids = x['input_ids']
        attention_mask = x['attention_mask']

        if rm_cached is None:
            llm_res = self.llm(
                input_ids=input_ids,
                attention_mask=attention_mask,
            )
        else:
            llm_res = self.llm(
                input_ids=input_ids[:, -1:],  # attention_mask=attention_mask,
                past_key_values=rm_cached["item_learner"],
                use_cache=False  # NOTE: unlike UserLearner.infer_gk (use_cache=True); with use_cache=False the returned past_key_values may be None
            )
            rm_cached["item_learner"] = llm_res.past_key_values

        embeds = llm_res.last_hidden_state
        # embeds shape: (bs, seq_len, hidden_size)
        shape = embeds.shape
        embeds = embeds.view(-1, shape[-1])  # (bs*seq_len, hidden_size)
        projected_embeds = self.projector(embeds)

        if rm_cached is None:
            return projected_embeds.view(shape[0], shape[1], -1)
        else:
            return projected_embeds.view(shape[0], shape[1], -1), rm_cached
learner.py
ADDED (+138 lines)

#!/usr/bin/env python
# -*-coding:utf-8 -*-

'''
@Desc: This is the implementation of PAL-B
'''

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel, AutoConfig

from .connector import Connector
from .tensor_initializer import TensorInitializer
from .custom_sfx import CustomSoftMax
from .itemLearner import ItemLearner
from .userLearner import UserLearner

from collections import defaultdict
from typing import Literal, Optional, Tuple

import logging
logger = logging.getLogger(__name__)

class BasePrefLearner(nn.Module):
    def __init__(
        self,
        d_hid: int,
        d_pref: int,
        k: int,
        llm_name: str,
        pref_learner_type: Literal["dist", "dist_normalization", "angle", "norm", "dist_logistic", "angle_hinge"],
        proj_arch: str,
        initializer_type: Literal["gaussian"],
        is_expectation_norm_init: bool,  # the tensor initialization parameters
        sfx_type: Literal["gumbel_softmax", "softmax"],
        sfx_temperature: float,
        is_temperature_learnable: bool,
        is_gumbel_hard: Optional[bool] = None,
        is_partition: bool = False,
        seed: int = 42,
        **kwargs
    ):
        super().__init__()
        self.pref_learner_type = pref_learner_type
        self.is_temperature_learnable = is_temperature_learnable
        # init all necessary modules
        model_config = AutoConfig.from_pretrained(llm_name)
        self.llm = AutoModel.from_pretrained(llm_name, from_tf=bool(".ckpt" in llm_name), config=model_config)
        self.tensor_initializer = TensorInitializer(initializer_type, seed, is_expectation_norm_init=is_expectation_norm_init)
        self.projector_f = Connector(cnct_arch=proj_arch, in_dims=d_hid, out_dims=d_pref)
        self.projectors_gk = [Connector(cnct_arch=proj_arch, in_dims=d_hid, out_dims=d_pref) for _ in range(k)]
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
        self.softmax_w = CustomSoftMax(sfx_type=sfx_type,
                                       temperature=sfx_temperature,
                                       is_temperature_learnable=is_temperature_learnable,
                                       is_gumbel_hard=is_gumbel_hard)
        self.item_learner = ItemLearner(
            llm=self.llm,
            projector=self.projector_f,
        )
        self.is_partition = is_partition
        self.user_learner = UserLearner(k=k, llm=self.llm, projectors=self.projectors_gk, softmax=self.softmax_w, is_partition=is_partition)
        logger.critical('Remember to call update_trainable_params() after the model is initialized.')

    def update_trainable_params(self, fix_modules: Tuple[str, ...] = ()):
        # capture params
        self.trainable_params = defaultdict(list)
        if "llm" not in fix_modules:
            self.trainable_params["llm"] = self.llm.parameters()
        else:
            self.llm.eval()
        if "itemLearnerProjector" not in fix_modules:
            self.trainable_params["projector_f"].extend(self.item_learner.projector.parameters())
        if "userLearnerProjector" not in fix_modules:
            self.trainable_params["projectors_gk"].extend(list(self.user_learner.projectors.parameters()))
        if "W" not in fix_modules:
            self.trainable_params["W"] = self.user_learner.W.parameters()
        if self.pref_learner_type in ["angle", "dist_logistic"] and "logit_scale" not in fix_modules:
            self.trainable_params["logit_scale"] = self.logit_scale
        if self.is_temperature_learnable and "temperature" not in fix_modules:
            self.trainable_params["temperature"] = self.softmax_w.temperature

    def map_to_pref_embedding_space(self, x, rm_cached=None):
        # x = ({
        #          'input_ids': prompt_input_ids,
        #          'attention_mask': prompt_attention_mask,
        #      },
        #      {
        #          'input_ids': eval_input_ids,
        #          'attention_mask': eval_attention_mask,
        #      })
        prompt, items = x
        if rm_cached is None:
            items_prime = self.item_learner(items)
            prompt_prime = self.user_learner(prompt)
            return items_prime, prompt_prime
        else:
            items_prime, rm_cached = self.item_learner(items, rm_cached)
            prompt_prime, rm_cached = self.user_learner(prompt, rm_cached)
            return items_prime, prompt_prime, rm_cached


class PrefLearner(BasePrefLearner):  # <f(x),f(u)>

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, x, rm_cached=None):
        items, prompt = x
        if rm_cached is None:
            items_prime, prompt_prime = self.map_to_pref_embedding_space((prompt, items))
        else:
            items_prime, prompt_prime, rm_cached = self.map_to_pref_embedding_space((prompt, items), rm_cached)
        logger.info(f"{items_prime[0]=}")
        logger.info(f"{prompt_prime[0]=}")
        logger.info(f"{items_prime.shape=}")
        logger.info(f"{prompt_prime.shape=}")
        if self.pref_learner_type == 'angle':
            prompt_last_prime = prompt_prime[:, -1, :]
            prompt_last_prime = prompt_last_prime.unsqueeze(1)
            prompt_last_prime = prompt_last_prime / torch.norm(prompt_last_prime, dim=-1, keepdim=True)
            items_last_prime = items_prime[:, -1, :]
            items_last_prime = items_last_prime.unsqueeze(1)
            items_last_prime = items_last_prime / torch.norm(items_last_prime, dim=-1, keepdim=True)
            logit_scale = self.logit_scale.exp()
            clamped_logit_scale = torch.clamp(logit_scale, max=100)
            logger.info(f"{prompt_last_prime.shape=}")
            logger.info(f"{items_last_prime.shape=}")
            sim_score = (prompt_last_prime * items_last_prime).sum(dim=-1) * clamped_logit_scale  # (bs, max_token_length)
            if rm_cached is None:
                return sim_score
            else:
                return sim_score, rm_cached
        else:
            raise NotImplementedError
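In the 'angle' branch above, the score is a CLIP-style scaled cosine similarity between the last-token user embedding and the last-token item embedding. A small sketch of that same computation on toy tensors (shapes only, no LLM involved):

import torch

bs, seq_len, d_pref = 2, 5, 512
prompt_prime = torch.randn(bs, seq_len, d_pref)  # user/prompt embeddings after the g_k mixture
items_prime = torch.randn(bs, seq_len, d_pref)   # item embeddings after projector f
logit_scale = torch.tensor(1 / 0.07).log()       # same initialization as the nn.Parameter above

u = prompt_prime[:, -1:, :]
u = u / u.norm(dim=-1, keepdim=True)             # unit-normalize the last-token user vector
x = items_prime[:, -1:, :]
x = x / x.norm(dim=-1, keepdim=True)             # unit-normalize the last-token item vector

scale = torch.clamp(logit_scale.exp(), max=100)  # clamped temperature, as in forward()
sim_score = (u * x).sum(dim=-1) * scale          # (bs, 1): cosine similarity times the scale
print(sim_score.shape)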
modeling_pal_b_rm.py
ADDED (+28 lines)

from transformers import PreTrainedModel
from .learner import PrefLearner
from .configuration_pal_b_rm import PAL_B_Config

class PAL_B_RM_opt(PreTrainedModel):
    config_class = PAL_B_Config

    def __init__(self, config):
        super().__init__(config)
        self.model = PrefLearner(
            d_hid=config.d_hid,
            d_pref=config.d_pref,
            k=config.k,
            llm_name=config.llm_name,
            pref_learner_type=config.pref_learner_type,
            proj_arch=config.proj_arch,
            initializer_type=config.initializer_type,
            is_expectation_norm_init=config.is_expectation_norm_init,
            sfx_type=config.sfx_type,
            sfx_temperature=config.sfx_temperature,
            is_temperature_learnable=config.is_temperature_learnable,
            is_gumbel_hard=config.is_gumbel_hard,
        )
        # self.model.user_learner.init_weight(uids)

    def forward(self, x):
        logits = self.model(x)
        return {'logits': logits}
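A minimal construction sketch, assuming the repository files are importable as a local Python package (hypothetically named pal_b_rm) and that the facebook/opt-350m backbone can be downloaded. It only builds the wrapper and registers trainable parameter groups, following the reminder logged in learner.py; it does not run a forward pass:

from pal_b_rm.configuration_pal_b_rm import PAL_B_Config
from pal_b_rm.modeling_pal_b_rm import PAL_B_RM_opt

config = PAL_B_Config()        # defaults: d_hid=512, d_pref=512, k=2, llm_name="facebook/opt-350m"
model = PAL_B_RM_opt(config)   # wraps the OPT backbone inside PrefLearner

# Freeze the backbone and collect the remaining trainable groups.
model.model.update_trainable_params(fix_modules=("llm",))
print(list(model.model.trainable_params.keys()))  # e.g. ['projector_f', 'projectors_gk', 'W', 'logit_scale']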
projector.py
ADDED (+81 lines)

from typing import Sequence
import torch.nn as nn
import torch


class Projector(nn.Module):

    in_dims: int
    out_dims: int
    latent_dims: Sequence[int]
    bias: bool
    dropout_p: float
    activation: str
    identity_map: bool
    use_batchnorm: bool

    def __init__(
        self,
        in_dims: int,
        out_dims: int,
        latent_dims: Sequence[int] = tuple([]),
        bias: bool = True,
        dropout_p: float = 0.2,
        activation: str = 'relu',
        identity_map=False,
        use_batchnorm: bool = False,
    ):
        super().__init__()

        self.in_dims = in_dims
        self.out_dims = out_dims
        self.bias = bias
        self.dropout_p = dropout_p
        self.latent_dims = latent_dims
        self.act = None
        self.identity_map = identity_map
        self.use_batchnorm = use_batchnorm

        if activation == 'relu':
            self.act = nn.ReLU
        elif activation == 'gelu':
            self.act = nn.GELU
        elif activation == 'linear':
            self.act = nn.Identity
        else:
            raise ValueError(f'no such activation {activation}')

        if identity_map == True:
            self.identity = nn.Identity()
            # self.alpha = nn.Parameter(torch.tensor(0.5))

        layer_dims = [in_dims] + list(latent_dims)
        layers = []

        for i in range(len(layer_dims) - 1):
            layers.append(nn.Linear(layer_dims[i], layer_dims[i + 1], bias=self.bias))
            if self.use_batchnorm:  # Add batch normalization layer if enabled
                layers.append(nn.BatchNorm1d(layer_dims[i + 1]))
            layers.extend([
                nn.Dropout(p=self.dropout_p),
                self.act()
            ])

        layers.append(nn.Linear(layer_dims[-1], out_dims, bias=self.bias))
        self.layers = nn.Sequential(*layers)

    def forward(self, x) -> torch.Tensor:
        """Forward pass of the projector model.

        Args:
            x: The input features.

        Returns:
            torch.Tensor: The projected features.

        """
        if self.identity_map:
            x = self.identity(x) + self.layers(x)
        else:
            x = self.layers(x)
        return x
pytorch_model.bin
CHANGED

@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:770877c170b8b51c6e6555de658213f0a6a1fca5c74370f1b8fed47cf6411bac
 size 1334487698
tensor_initializer.py
ADDED (+32 lines)

import numpy as np
import torch

class TensorInitializer:

    def __init__(self, type: str, seed: int, is_expectation_norm_init: bool = False):
        self.initializer_type = type
        self.rng = np.random.default_rng(seed)
        self.is_expectation_norm_init = is_expectation_norm_init

    def gaussian_initializer(
        self,
        dim: int,
        size: int,
    ) -> torch.Tensor:

        mean = np.zeros(dim)
        if self.is_expectation_norm_init:
            # expectation normalization
            cov = 1 / dim * np.eye(dim)
            return torch.tensor(self.rng.multivariate_normal(mean, cov, size), dtype=torch.float32)  # .float()
        else:
            # enforced normalization
            cov = np.eye(dim)
            unnorm_tensor = torch.tensor(self.rng.multivariate_normal(mean, cov, size), dtype=torch.float32)  # .float()
            return unnorm_tensor / torch.norm(unnorm_tensor, dim=1, keepdim=True)

    def __call__(self, *args, **kwargs):
        if self.initializer_type == 'gaussian':
            return self.gaussian_initializer(*args, **kwargs)
        else:
            raise ValueError(f'Unknown initializer type: {self.initializer_type}')
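The two branches differ only in how the row norms behave: with is_expectation_norm_init=True the covariance I/dim makes the expected squared norm of each row equal to 1, while the default branch samples from N(0, I) and then divides by the row norm, so every row has norm exactly 1. A quick check of that behavior, assuming tensor_initializer.py is on the Python path:

import torch
from tensor_initializer import TensorInitializer  # assumes the file is importable as-is

init = TensorInitializer('gaussian', seed=0, is_expectation_norm_init=True)
v = init(dim=512, size=4)          # (4, 512), rows drawn from N(0, I/512)
print(v.shape, v.norm(dim=1))      # row norms scatter around 1.0

init = TensorInitializer('gaussian', seed=0, is_expectation_norm_init=False)
v = init(dim=512, size=4)          # rows drawn from N(0, I), then unit-normalized
print(v.norm(dim=1))               # exactly 1.0 up to float precision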
tensor_merger.py
ADDED (+16 lines)

import torch


class TensorMerger:

    def __init__(self, merger_type) -> None:
        self.merger_type = merger_type

    def concat(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return torch.cat([x, y], dim=1)

    def __call__(self, x: torch.Tensor, y: torch.Tensor):
        if self.merger_type == 'concat':
            return self.concat(x, y)
        else:
            raise ValueError(f'Unknown merger type: {self.merger_type}')
userLearner.py
ADDED (+151 lines)

import torch
import torch.nn as nn
import torch.nn.functional as F
from .connector import Connector
from .projector import Projector
from .tensor_initializer import TensorInitializer
from .custom_sfx import CustomSoftMax
import numpy as np
import warnings

from typing import Literal

import logging
logger = logging.getLogger(__name__)

class UserLearner(nn.Module):

    k: int  # the number of groups
    llm: nn.Module
    projectors: list[Projector]
    u_id_set: set
    softmax: nn.Module
    is_partition: bool

    def __init__(
        self,
        k: int,
        llm: nn.Module,
        projectors: list[Projector],
        softmax: nn.Module,
        is_partition: bool = False,
    ):
        super().__init__()

        self.k = k
        self.llm = llm
        self.softmax = softmax
        # init user_id registration table and user weights dictionary
        self.u_id_set = set()
        self.W = nn.ParameterDict()
        self.tmp_store_user_ideal_points = None
        # register all k projectors in the moduledict
        assert len(projectors) == k, f"The num of projectors should match up with num of groups: {k} != {len(projectors)}"
        self.projectors = nn.ModuleDict()
        for i in range(k):
            self.projectors[str(i)] = projectors[i]
        self.is_partition = is_partition

    def init_weight(self, u_ids: list, reinit: bool = False):
        for u_id in u_ids:
            if u_id not in self.u_id_set or reinit:
                self.W[u_id] = nn.Parameter(
                    torch.randn((self.k), dtype=torch.float32),
                    requires_grad=True,
                ).to(next(self.projectors[str(0)].parameters()).device)
                self.u_id_set.add(u_id)
            else:
                logger.warning('wait? same user?')

    def get_sfx_w(self, u_ids: list):
        w = torch.stack([self.W[key] for key in u_ids], dim=0)  # (bs, k)
        w = self.softmax(w)
        return w

    def get_hardmax_w(self, u_ids: list):
        w = torch.stack([self.W[key] for key in u_ids], dim=0)
        w = F.one_hot(w.argmax(dim=1), num_classes=self.k).float()  # (bs, k)
        return w

    def infer_gk(self, prompt_tokens, rm_cached=None):
        '''
        prompt_tokens: {'input_ids': torch.tensor, 'attention_mask': torch.tensor}
        If you want to activate rm_cached, please pass in the rm_cached dict or an empty dict.
        '''
        input_ids = prompt_tokens['input_ids']
        attention_mask = prompt_tokens['attention_mask']

        if rm_cached is None:
            embeds = self.llm(
                input_ids=input_ids,
                attention_mask=attention_mask,
            ).last_hidden_state
        else:
            res = self.llm(
                input_ids=input_ids[:, -1:],
                # attention_mask=attention_mask,
                past_key_values=rm_cached["user_learner"],
                use_cache=True
            )
            rm_cached["user_learner"] = res.past_key_values
            embeds = res.last_hidden_state

        # embeds shape: (bs, seq_len, hid_dim)
        shape = embeds.shape
        embeds = embeds.view(-1, shape[-1])  # (bs*seq_len, hid_dim)
        # g(embeds) shape: (bs*seq_len, hid_dim) -> (bs*seq_len, pref_dim)
        logits = torch.stack([g(embeds).view(shape[0], shape[1], -1) for g in self.projectors.values()], dim=1)
        if rm_cached is None:
            return logits
        else:
            return logits, rm_cached  # (bs, k, seq_len, hidden_size)

    def return_user_ideal_points(self):
        if self.tmp_store_user_ideal_points is None:
            raise ValueError('No user ideal points stored')
        return self.tmp_store_user_ideal_points

    def forward(self, prompt_tokens, rm_cached=None):  # only pass the prompt tokens
        '''
        prompt_tokens: {'input_ids': torch.tensor, 'attention_mask': torch.tensor}
        '''
        if rm_cached is None:
            prompt_logits = self.infer_gk(prompt_tokens)
        else:
            prompt_logits, rm_cached = self.infer_gk(prompt_tokens, rm_cached)
        bs = prompt_tokens['input_ids'].size(0)
        # NOTE: `mix_weight` is not defined anywhere in this file; as uploaded, forward()
        # appears to expect it to be supplied externally (e.g. a (k,) mixture-weight tensor),
        # otherwise the next line raises NameError. Note also that forward() always returns a
        # (y_hat, rm_cached) tuple, even when rm_cached is None, unlike ItemLearner.forward().
        assert sum(mix_weight) == 1
        # w = self.softmax(mix_weight.repeat(bs, 1))
        w = mix_weight.repeat(bs, 1)
        logger.info(f"{w=}")
        logger.info(f"{w.shape=}")
        w = w.unsqueeze(-1).unsqueeze(-1)
        y_hat = (w * prompt_logits).sum(dim=1)
        self.tmp_store_user_ideal_points = y_hat
        return y_hat, rm_cached

    def eval(self):
        super().eval()
        if self.is_partition:
            warnings.warn("UserPromptLearner(Partition version) is in eval mode: argmax")
            self.is_argmax = True
        else:
            warnings.warn("UserPromptLearner(Mixture version) is in eval mode: sfx")
            self.is_argmax = False

    def train(self, mode: bool = True):
        super().train(mode)
        if mode:
            if self.is_partition:
                warnings.warn("UserPromptLearner(Partition version) is in train mode: sfx")
                self.is_argmax = False
            else:
                warnings.warn("UserPromptLearner(Mixture version) is in train mode: sfx")
                self.is_argmax = False
        else:
            if self.is_partition:
                warnings.warn("UserPromptLearner(Partition version) is in eval mode: argmax")
                self.is_argmax = True
            else:
                warnings.warn("UserPromptLearner(Mixture version) is in eval mode: sfx")
                self.is_argmax = False
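A small usage sketch for the per-user weight machinery, again assuming the repository files are importable as a package (hypothetically pal_b_rm). It registers two user ids and inspects the softmax mixture weights without touching the LLM path:

import torch.nn as nn
from pal_b_rm.connector import Connector
from pal_b_rm.custom_sfx import CustomSoftMax
from pal_b_rm.userLearner import UserLearner

k = 2
projectors = [Connector(cnct_arch="mlp2-gelu-dropout0", in_dims=512, out_dims=512) for _ in range(k)]
sfx = CustomSoftMax(sfx_type="softmax", temperature=1.0, is_temperature_learnable=False)

# The LLM is not used for weight registration, so a placeholder module is enough here.
learner = UserLearner(k=k, llm=nn.Identity(), projectors=projectors, softmax=sfx)
learner.init_weight(["user_0", "user_1"])           # one k-dim weight vector per user id
print(learner.get_sfx_w(["user_0", "user_1"]))      # (2, k) rows summing to 1
print(learner.get_hardmax_w(["user_0", "user_1"]))  # one-hot argmax version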