daiweichen committed on
Commit
3a2aa34
·
verified ·
1 Parent(s): 5da6518

Upload PAL_B_RM_opt

config.json CHANGED
@@ -2,6 +2,10 @@
  "architectures": [
    "PAL_B_RM_opt"
  ],
+ "auto_map": {
+   "AutoConfig": "configuration_pal_b_rm.PAL_B_Config",
+   "AutoModel": "modeling_pal_b_rm.PAL_B_RM_opt"
+ },
  "d_hid": 512,
  "d_pref": 512,
  "initializer_type": "gaussian",
configuration_pal_b_rm.py ADDED
@@ -0,0 +1,35 @@
from typing import Optional

from transformers import PretrainedConfig


class PAL_B_Config(PretrainedConfig):
    model_type = "facebook/opt"  # PAL reward model aligned with the OPT model family

    def __init__(
        self,
        d_hid: int = 512,
        d_pref: int = 512,
        k: int = 2,
        llm_name: str = "facebook/opt-350m",
        pref_learner_type: str = "angle",
        proj_arch: str = "mlp2-gelu-dropout0",
        initializer_type: str = "gaussian",
        is_expectation_norm_init: bool = False,
        sfx_type: str = "softmax",
        sfx_temperature: float = 1.0,
        is_temperature_learnable: bool = False,
        is_gumbel_hard: Optional[bool] = None,
        **kwargs,
    ):
        self.d_hid = d_hid
        self.d_pref = d_pref
        self.k = k
        self.llm_name = llm_name
        self.pref_learner_type = pref_learner_type
        self.proj_arch = proj_arch
        self.initializer_type = initializer_type
        self.is_expectation_norm_init = is_expectation_norm_init
        self.sfx_type = sfx_type
        self.sfx_temperature = sfx_temperature
        self.is_temperature_learnable = is_temperature_learnable
        self.is_gumbel_hard = is_gumbel_hard
        super().__init__(**kwargs)
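
For reference, a small usage sketch of the config class (this file has no relative imports, so it can be imported directly once downloaded; the directory name below is arbitrary):

from configuration_pal_b_rm import PAL_B_Config

cfg = PAL_B_Config()                                   # all defaults
custom = PAL_B_Config(k=4, sfx_type="gumbel_softmax", is_gumbel_hard=True)
print(cfg.k, custom.k)                                 # 2 4
custom.save_pretrained("pal_b_config_demo")            # writes config.json with these fields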
connector.py ADDED
@@ -0,0 +1,112 @@
from .projector import Projector
import torch.nn as nn
import re


class Connector(nn.Module):
    """Wraps a Projector (or identity) built from a compact architecture string,
    e.g. "mlp2-gelu-dropout0" or "mlp1-relu-dropout2-residual-batchnorm"."""

    def __init__(self, in_dims: int, out_dims: int, cnct_arch: str):
        super().__init__()
        pattern = r"mlp(\d+)-(relu|gelu|linear)-dropout(\d+)?(-residual-batchnorm|-batchnorm-residual|-residual|-batchnorm|-nobias)?"
        match = re.match(pattern, cnct_arch)
        if match:
            layers = int(match.group(1))
            act = match.group(2)
            dropout_digits = match.group(3)
            if dropout_digits is None:  # the digits after "dropout" are optional in the pattern
                dropout_p = 0.0
            else:
                # "dropout2" -> 0.2, "dropout75" -> 0.75, "dropout0" -> 0.0
                dropout_p = int(dropout_digits) / 10 ** len(dropout_digits)
            if match.group(4) is not None:
                residual = "-residual" in match.group(4)
                batchnorm = "-batchnorm" in match.group(4)
                nobias = "-nobias" in match.group(4)
            else:
                residual = False
                batchnorm = False
                nobias = False
            latent_dims = [out_dims] * layers
            self.mlp = Projector(
                in_dims=in_dims,
                out_dims=out_dims,
                latent_dims=latent_dims,
                bias=not nobias,
                dropout_p=dropout_p,
                activation=act,
                identity_map=residual,
                use_batchnorm=batchnorm,
            )
        elif cnct_arch == 'identity':
            self.mlp = nn.Identity()
        else:
            raise ValueError(f'no such connection architecture {cnct_arch}')

    def forward(self, x):
        # defined as forward (rather than __call__) so nn.Module call hooks still apply
        return self.mlp(x)


if __name__ == "__main__":
    m = Connector(cnct_arch='identity', in_dims=4096, out_dims=768)
    print(m)
    m = Connector(cnct_arch='mlp1-relu-dropout2-residual', in_dims=4096, out_dims=768)
    print(m)
    m = Connector(cnct_arch='mlp1-relu-dropout2-batchnorm', in_dims=4096, out_dims=768)
    print(m)
    m = Connector(cnct_arch='mlp1-relu-dropout2-residual-batchnorm', in_dims=4096, out_dims=768)
    print(m)
    m = Connector(cnct_arch='mlp3-gelu-dropout2', in_dims=4096, out_dims=768)
    print(m)
    m = Connector(cnct_arch='mlp16-relu-dropout75', in_dims=4096, out_dims=768)
    print(m)
    m = Connector(cnct_arch='mlp0-linear-dropout0', in_dims=4096, out_dims=768)
    print(m)
    m = Connector(cnct_arch='mlp0-linear-dropout0-nobias', in_dims=4096, out_dims=768)
    print(m)
    m = Connector(cnct_arch='mlp2-linear-dropout0-nobias', in_dims=4096, out_dims=768)
    print(m)

    m = Connector(cnct_arch='mlp2-gelu-dropout0', in_dims=512, out_dims=512)
    count = 0
    for p in m.parameters():
        count += p.numel()
    print(count)
custom_sfx.py ADDED
@@ -0,0 +1,61 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from typing import Literal, Optional

import logging
logging.basicConfig(level=logging.WARNING, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


class CustomSoftMax(nn.Module):
    def __init__(
        self,
        sfx_type: Literal['gumbel_softmax', 'softmax'],
        temperature: float,
        is_temperature_learnable: bool,
        is_gumbel_hard: Optional[bool] = None,
        *args,
        **kwargs,
    ) -> None:
        super().__init__()
        self.sfx_type = sfx_type
        assert not is_temperature_learnable, 'is_temperature_learnable is prohibited in this version: a learnable temperature can drift negative'
        self.temperature = nn.Parameter(torch.tensor([float(temperature)]), requires_grad=is_temperature_learnable)
        self.is_gumbel_hard = is_gumbel_hard
        self.args = args
        self.kwargs = kwargs

    def forward(self, x):
        # x: (bs, dims)
        if self.sfx_type == 'gumbel_softmax':
            if self.is_gumbel_hard is not None:
                return F.gumbel_softmax(x, tau=self.temperature, hard=self.is_gumbel_hard, dim=1)
            else:
                raise ValueError('is_gumbel_hard is not passed')
        elif self.sfx_type == 'softmax':
            return F.softmax(x / self.temperature, dim=1)
        else:
            raise NotImplementedError(f'{self.sfx_type} is not implemented yet')


if __name__ == "__main__":
    sfx = CustomSoftMax(sfx_type='gumbel_softmax', temperature=1, is_temperature_learnable=False, is_gumbel_hard=True)
    x = torch.randn(10, 3)  # (bs, dims)
    print(x.shape)
    print(sfx(x))

    # Note: passing is_temperature_learnable=True would trip the assert above,
    # so the demos below keep the temperature fixed.
    sfx = CustomSoftMax(sfx_type='gumbel_softmax', temperature=1, is_temperature_learnable=False, is_gumbel_hard=True)
    x = torch.randn(10, 3)  # (bs, dims)
    print(x.shape)
    print(sfx(x))

    sfx = CustomSoftMax(sfx_type='softmax', temperature=1, is_temperature_learnable=False)
    x = torch.randn(10, 3)
    print(sfx(x))

    sfx = CustomSoftMax(sfx_type='softmax', temperature=0.01, is_temperature_learnable=False, is_gumbel_hard=None)
    x = torch.randn(10, 3)
    print(sfx(x))
itemLearner.py ADDED
@@ -0,0 +1,53 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from .connector import Connector
from .projector import Projector
from .tensor_merger import TensorMerger
import numpy as np

from typing import Literal, Optional, Tuple
import logging
logging.basicConfig(level=logging.WARNING, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


class ItemLearner(nn.Module):

    llm: nn.Module
    projector: nn.Module

    def __init__(self, llm, projector):
        super().__init__()
        self.llm = llm
        self.projector = projector

    def forward(self, x, rm_cached=None):  # only pass the generated data
        '''
        x = {'input_ids': torch.tensor, 'attention_mask': torch.tensor}
        '''
        input_ids = x['input_ids']
        attention_mask = x['attention_mask']

        if rm_cached is None:
            llm_res = self.llm(
                input_ids=input_ids,
                attention_mask=attention_mask,
            )
        else:
            llm_res = self.llm(
                input_ids=input_ids[:, -1:],  # attention_mask is omitted for the cached step
                past_key_values=rm_cached["item_learner"],
                use_cache=True,  # must be True so past_key_values are returned for the next step
            )
            rm_cached["item_learner"] = llm_res.past_key_values

        embeds = llm_res.last_hidden_state  # (bs, seq_len, hidden_size)
        shape = embeds.shape
        embeds = embeds.view(-1, shape[-1])  # (bs*seq_len, hidden_size)
        projected_embeds = self.projector(embeds)

        if rm_cached is None:
            return projected_embeds.view(shape[0], shape[1], -1)
        else:
            return projected_embeds.view(shape[0], shape[1], -1), rm_cached
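
A minimal usage sketch of ItemLearner. It assumes these modules are installed as a package (they use relative imports), with "pal_b_rm" as a hypothetical package name, and uses the default facebook/opt-350m backbone (whose last_hidden_state dimension is 512) together with the repo's default proj_arch:

import torch
from transformers import AutoModel, AutoTokenizer
from pal_b_rm.connector import Connector        # hypothetical package name
from pal_b_rm.itemLearner import ItemLearner

llm = AutoModel.from_pretrained("facebook/opt-350m")
projector = Connector(in_dims=512, out_dims=512, cnct_arch="mlp2-gelu-dropout0")
item_learner = ItemLearner(llm=llm, projector=projector)

tok = AutoTokenizer.from_pretrained("facebook/opt-350m")
batch = tok(["a sample response", "another, longer sample response"],
            return_tensors="pt", padding=True)
with torch.no_grad():
    out = item_learner({"input_ids": batch["input_ids"],
                        "attention_mask": batch["attention_mask"]})
print(out.shape)  # (batch, seq_len, d_pref) = (2, seq_len, 512)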
learner.py ADDED
@@ -0,0 +1,138 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

'''
@Desc: This is the implementation of PAL-B
'''

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel, AutoConfig

from .connector import Connector
from .tensor_initializer import TensorInitializer
from .custom_sfx import CustomSoftMax
from .itemLearner import ItemLearner
from .userLearner import UserLearner

from collections import defaultdict
from typing import Literal, Optional, Tuple

import logging
logger = logging.getLogger(__name__)


class BasePrefLearner(nn.Module):
    def __init__(
        self,
        d_hid: int,
        d_pref: int,
        k: int,
        llm_name: str,
        pref_learner_type: Literal["dist", "dist_normalization", "angle", "norm", "dist_logistic", "angle_hinge"],
        proj_arch: str,
        initializer_type: Literal["gaussian"],
        is_expectation_norm_init: bool,  # tensor-initialization parameter
        sfx_type: Literal["gumbel_softmax", "softmax"],
        sfx_temperature: float,
        is_temperature_learnable: bool,
        is_gumbel_hard: Optional[bool] = None,
        is_partition: bool = False,
        seed: int = 42,
        **kwargs,
    ):
        super().__init__()
        self.pref_learner_type = pref_learner_type
        self.is_temperature_learnable = is_temperature_learnable
        # init all necessary modules
        model_config = AutoConfig.from_pretrained(llm_name)
        self.llm = AutoModel.from_pretrained(llm_name, from_tf=bool(".ckpt" in llm_name), config=model_config)
        self.tensor_initializer = TensorInitializer(initializer_type, seed, is_expectation_norm_init=is_expectation_norm_init)
        self.projector_f = Connector(cnct_arch=proj_arch, in_dims=d_hid, out_dims=d_pref)
        # plain Python list: these k projectors are registered as submodules inside UserLearner's ModuleDict
        self.projectors_gk = [Connector(cnct_arch=proj_arch, in_dims=d_hid, out_dims=d_pref) for _ in range(k)]
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
        self.softmax_w = CustomSoftMax(sfx_type=sfx_type,
                                       temperature=sfx_temperature,
                                       is_temperature_learnable=is_temperature_learnable,
                                       is_gumbel_hard=is_gumbel_hard)
        self.item_learner = ItemLearner(
            llm=self.llm,
            projector=self.projector_f,
        )
        self.is_partition = is_partition
        self.user_learner = UserLearner(k=k, llm=self.llm, projectors=self.projectors_gk, softmax=self.softmax_w, is_partition=is_partition)
        logger.critical('🛑 Remember to call update_trainable_params() after the model is initialized.')

    def update_trainable_params(self, fix_modules: Tuple[str, ...] = ()):
        # capture params
        self.trainable_params = defaultdict(list)
        if "llm" not in fix_modules:
            self.trainable_params["llm"] = self.llm.parameters()
        else:
            self.llm.eval()
        if "itemLearnerProjector" not in fix_modules:
            self.trainable_params["projector_f"].extend(self.item_learner.projector.parameters())
        if "userLearnerProjector" not in fix_modules:
            self.trainable_params["projectors_gk"].extend(list(self.user_learner.projectors.parameters()))
        if "W" not in fix_modules:
            self.trainable_params["W"] = self.user_learner.W.parameters()
        if self.pref_learner_type in ["angle", "dist_logistic"] and "logit_scale" not in fix_modules:
            self.trainable_params["logit_scale"] = self.logit_scale
        if self.is_temperature_learnable and "temperature" not in fix_modules:
            self.trainable_params["temperature"] = self.softmax_w.temperature

    def map_to_pref_embedding_space(self, x, rm_cached=None):
        # x = (
        #     {'input_ids': prompt_input_ids, 'attention_mask': prompt_attention_mask},
        #     {'input_ids': eval_input_ids, 'attention_mask': eval_attention_mask},
        # )
        prompt, items = x
        if rm_cached is None:
            items_prime = self.item_learner(items)
            prompt_prime = self.user_learner(prompt)
            return items_prime, prompt_prime
        else:
            items_prime, rm_cached = self.item_learner(items, rm_cached)
            prompt_prime, rm_cached = self.user_learner(prompt, rm_cached)
            return items_prime, prompt_prime, rm_cached


class PrefLearner(BasePrefLearner):  # <f(x), f(u)>

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, x, rm_cached=None):
        items, prompt = x
        if rm_cached is None:
            items_prime, prompt_prime = self.map_to_pref_embedding_space((prompt, items))
        else:
            items_prime, prompt_prime, rm_cached = self.map_to_pref_embedding_space((prompt, items), rm_cached)
        logger.info(f"{items_prime[0]=}")
        logger.info(f"{prompt_prime[0]=}")
        logger.info(f"{items_prime.shape=}")
        logger.info(f"{prompt_prime.shape=}")
        if self.pref_learner_type == 'angle':
            prompt_last_prime = prompt_prime[:, -1, :]
            prompt_last_prime = prompt_last_prime.unsqueeze(1)
            prompt_last_prime = prompt_last_prime / torch.norm(prompt_last_prime, dim=-1, keepdim=True)
            items_last_prime = items_prime[:, -1, :]
            items_last_prime = items_last_prime.unsqueeze(1)
            items_last_prime = items_last_prime / torch.norm(items_last_prime, dim=-1, keepdim=True)
            logit_scale = self.logit_scale.exp()
            clamped_logit_scale = torch.clamp(logit_scale, max=100)
            logger.info(f"{prompt_last_prime.shape=}")
            logger.info(f"{items_last_prime.shape=}")
            sim_score = (prompt_last_prime * items_last_prime).sum(dim=-1) * clamped_logit_scale  # (bs, 1)
            if rm_cached is None:
                return sim_score
            else:
                return sim_score, rm_cached
        else:
            raise NotImplementedError
modeling_pal_b_rm.py ADDED
@@ -0,0 +1,28 @@
from transformers import PreTrainedModel
from .learner import PrefLearner
from .configuration_pal_b_rm import PAL_B_Config


class PAL_B_RM_opt(PreTrainedModel):
    config_class = PAL_B_Config

    def __init__(self, config):
        super().__init__(config)
        self.model = PrefLearner(
            d_hid=config.d_hid,
            d_pref=config.d_pref,
            k=config.k,
            llm_name=config.llm_name,
            pref_learner_type=config.pref_learner_type,
            proj_arch=config.proj_arch,
            initializer_type=config.initializer_type,
            is_expectation_norm_init=config.is_expectation_norm_init,
            sfx_type=config.sfx_type,
            sfx_temperature=config.sfx_temperature,
            is_temperature_learnable=config.is_temperature_learnable,
            is_gumbel_hard=config.is_gumbel_hard,
        )
        # self.model.user_learner.init_weight(uids)

    def forward(self, x):
        logits = self.model(x)
        return {'logits': logits}
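
For local experimentation (without going through the Hub), the wrapper can also be built directly from a default config. A sketch, again assuming the files live in a hypothetical package named "pal_b_rm" because of the relative imports; constructing the model downloads the facebook/opt-350m backbone:

from pal_b_rm.configuration_pal_b_rm import PAL_B_Config
from pal_b_rm.modeling_pal_b_rm import PAL_B_RM_opt

config = PAL_B_Config()   # defaults: facebook/opt-350m backbone, k=2, angle preference learner
model = PAL_B_RM_opt(config)
print(sum(p.numel() for p in model.parameters()))

The "auto_map" block added to config.json above is what register_for_auto_class() followed by save_pretrained/push_to_hub typically writes, which is presumably how those entries were produced for this commit.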
projector.py ADDED
@@ -0,0 +1,81 @@
from typing import Sequence
import torch.nn as nn
import torch


class Projector(nn.Module):

    in_dims: int
    out_dims: int
    latent_dims: Sequence[int]
    bias: bool
    dropout_p: float
    activation: str
    identity_map: bool
    use_batchnorm: bool

    def __init__(
        self,
        in_dims: int,
        out_dims: int,
        latent_dims: Sequence[int] = tuple(),
        bias: bool = True,
        dropout_p: float = 0.2,
        activation: str = 'relu',
        identity_map: bool = False,
        use_batchnorm: bool = False,
    ):
        super().__init__()

        self.in_dims = in_dims
        self.out_dims = out_dims
        self.bias = bias
        self.dropout_p = dropout_p
        self.latent_dims = latent_dims
        self.act = None
        self.identity_map = identity_map
        self.use_batchnorm = use_batchnorm

        if activation == 'relu':
            self.act = nn.ReLU
        elif activation == 'gelu':
            self.act = nn.GELU
        elif activation == 'linear':
            self.act = nn.Identity
        else:
            raise ValueError(f'no such activation {activation}')

        if identity_map:
            self.identity = nn.Identity()
            # self.alpha = nn.Parameter(torch.tensor(0.5))

        layer_dims = [in_dims] + list(latent_dims)
        layers = []

        for i in range(len(layer_dims) - 1):
            layers.append(nn.Linear(layer_dims[i], layer_dims[i + 1], bias=self.bias))
            if self.use_batchnorm:  # add batch normalization if enabled
                layers.append(nn.BatchNorm1d(layer_dims[i + 1]))
            layers.extend([
                nn.Dropout(p=self.dropout_p),
                self.act(),
            ])

        layers.append(nn.Linear(layer_dims[-1], out_dims, bias=self.bias))
        self.layers = nn.Sequential(*layers)

    def forward(self, x) -> torch.Tensor:
        """Forward pass of the projector model.

        Args:
            x: The input features.

        Returns:
            torch.Tensor: The projected features.
        """
        if self.identity_map:
            x = self.identity(x) + self.layers(x)
        else:
            x = self.layers(x)
        return x
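
A quick shape check for Projector (no assumptions beyond the file above; the dimensions are arbitrary):

import torch
from projector import Projector   # direct import works: this file has no relative imports

proj = Projector(in_dims=4096, out_dims=768, latent_dims=[768, 768],
                 activation="gelu", dropout_p=0.1)
x = torch.randn(32, 4096)
print(proj(x).shape)               # torch.Size([32, 768])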
pytorch_model.bin CHANGED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:ac37d5858537b0f20041627543ee7c4022bbf8e069e1941f8dc40d6c282425bb
+ oid sha256:770877c170b8b51c6e6555de658213f0a6a1fca5c74370f1b8fed47cf6411bac
  size 1334487698
tensor_initializer.py ADDED
@@ -0,0 +1,32 @@
import numpy as np
import torch


class TensorInitializer:

    def __init__(self, type: str, seed: int, is_expectation_norm_init: bool = False):
        self.initializer_type = type
        self.rng = np.random.default_rng(seed)
        self.is_expectation_norm_init = is_expectation_norm_init

    def gaussian_initializer(
        self,
        dim: int,
        size: int,
    ) -> torch.Tensor:
        mean = np.zeros(dim)
        if self.is_expectation_norm_init:
            # expectation normalization: covariance (1/dim) * I, so E[||x||^2] = 1
            cov = 1 / dim * np.eye(dim)
            return torch.tensor(self.rng.multivariate_normal(mean, cov, size), dtype=torch.float32)
        else:
            # enforced normalization: each row is rescaled to unit L2 norm
            cov = np.eye(dim)
            unnorm_tensor = torch.tensor(self.rng.multivariate_normal(mean, cov, size), dtype=torch.float32)
            return unnorm_tensor / torch.norm(unnorm_tensor, dim=1, keepdim=True)

    def __call__(self, *args, **kwargs):
        if self.initializer_type == 'gaussian':
            return self.gaussian_initializer(*args, **kwargs)
        else:
            raise ValueError(f'Unknown initializer type: {self.initializer_type}')
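
A small sanity check of the two initialization modes (dim chosen to match d_pref=512 from the config; the printed norms are illustrative):

import torch
from tensor_initializer import TensorInitializer   # direct import works: no relative imports here

init = TensorInitializer("gaussian", seed=42, is_expectation_norm_init=False)
t = init(dim=512, size=4)                 # enforced normalization
print(t.shape)                            # torch.Size([4, 512])
print(torch.norm(t, dim=1))               # all ones: rows are rescaled to unit norm

init_e = TensorInitializer("gaussian", seed=42, is_expectation_norm_init=True)
t_e = init_e(dim=512, size=4)             # expectation normalization
print(torch.norm(t_e, dim=1))             # close to 1 on average, but not exactly 1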
tensor_merger.py ADDED
@@ -0,0 +1,16 @@
import torch


class TensorMerger:

    def __init__(self, merger_type) -> None:
        self.merger_type = merger_type

    def concat(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return torch.cat([x, y], dim=1)

    def __call__(self, x: torch.Tensor, y: torch.Tensor):
        if self.merger_type == 'concat':
            return self.concat(x, y)
        else:
            raise ValueError(f'Unknown merger type: {self.merger_type}')
userLearner.py ADDED
@@ -0,0 +1,151 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from .connector import Connector
from .projector import Projector
from .tensor_initializer import TensorInitializer
from .custom_sfx import CustomSoftMax
import numpy as np
import warnings

from typing import Literal, Optional

import logging
logger = logging.getLogger(__name__)


class UserLearner(nn.Module):

    k: int  # the number of groups
    llm: nn.Module
    projectors: list[Projector]
    u_id_set: set
    softmax: nn.Module
    is_partition: bool

    def __init__(
        self,
        k: int,
        llm: nn.Module,
        projectors: list[Projector],
        softmax: nn.Module,
        is_partition: bool = False,
    ):
        super().__init__()

        self.k = k
        self.llm = llm
        self.softmax = softmax
        # init user_id registration table and user weights dictionary
        self.u_id_set = set()
        self.W = nn.ParameterDict()
        self.tmp_store_user_ideal_points = None
        # register all k projectors in the ModuleDict
        assert len(projectors) == k, f"The number of projectors should match the number of groups: {k} != {len(projectors)}"
        self.projectors = nn.ModuleDict()
        for i in range(k):
            self.projectors[str(i)] = projectors[i]
        self.is_partition = is_partition

    def init_weight(self, u_ids: list, reinit: bool = False):
        for u_id in u_ids:
            if u_id not in self.u_id_set or reinit:
                # create the parameter directly on the projectors' device so the
                # value stored in the ParameterDict stays an nn.Parameter
                device = next(self.projectors["0"].parameters()).device
                self.W[u_id] = nn.Parameter(
                    torch.randn(self.k, dtype=torch.float32, device=device),
                    requires_grad=True,
                )
                self.u_id_set.add(u_id)
            else:
                logger.warning('👋 wait? same user?')

    def get_sfx_w(self, u_ids: list):
        w = torch.stack([self.W[key] for key in u_ids], dim=0)  # (bs, k)
        w = self.softmax(w)
        return w

    def get_hardmax_w(self, u_ids: list):
        w = torch.stack([self.W[key] for key in u_ids], dim=0)
        w = F.one_hot(w.argmax(dim=1), num_classes=self.k).float()  # (bs, k)
        return w

    def infer_gk(self, prompt_tokens, rm_cached=None):
        '''
        prompt_tokens: {'input_ids': torch.tensor, 'attention_mask': torch.tensor}
        To activate rm_cached, pass in the rm_cached dict (or an empty dict).
        '''
        input_ids = prompt_tokens['input_ids']
        attention_mask = prompt_tokens['attention_mask']

        if rm_cached is None:
            embeds = self.llm(
                input_ids=input_ids,
                attention_mask=attention_mask,
            ).last_hidden_state
        else:
            res = self.llm(
                input_ids=input_ids[:, -1:],
                # attention_mask=attention_mask,
                past_key_values=rm_cached["user_learner"],
                use_cache=True,
            )
            rm_cached["user_learner"] = res.past_key_values
            embeds = res.last_hidden_state

        # embeds shape: (bs, seq_len, hid_dim)
        shape = embeds.shape
        embeds = embeds.view(-1, shape[-1])  # (bs*seq_len, hid_dim)
        # g(embeds): (bs*seq_len, hid_dim) -> (bs*seq_len, pref_dim)
        logits = torch.stack([g(embeds).view(shape[0], shape[1], -1) for g in self.projectors.values()], dim=1)
        if rm_cached is None:
            return logits  # (bs, k, seq_len, pref_dim)
        else:
            return logits, rm_cached  # (bs, k, seq_len, pref_dim)

    def return_user_ideal_points(self):
        if self.tmp_store_user_ideal_points is None:
            raise ValueError('No user ideal points stored')
        return self.tmp_store_user_ideal_points

    def forward(self, prompt_tokens, rm_cached=None, mix_weight=None):  # only pass the prompt tokens
        '''
        prompt_tokens: {'input_ids': torch.tensor, 'attention_mask': torch.tensor}

        Note: the uploaded file referenced an undefined `mix_weight` here; it is now an
        optional argument, and when omitted a uniform mixture over the k group projectors
        is used as a placeholder (the intended weighting presumably comes from self.W).
        '''
        if rm_cached is None:
            prompt_logits = self.infer_gk(prompt_tokens)
        else:
            prompt_logits, rm_cached = self.infer_gk(prompt_tokens, rm_cached)
        bs = prompt_tokens['input_ids'].size(0)
        if mix_weight is None:
            mix_weight = torch.full((self.k,), 1.0 / self.k, device=prompt_logits.device)
        assert torch.isclose(mix_weight.sum(), torch.tensor(1.0, device=mix_weight.device))
        # w = self.softmax(mix_weight.repeat(bs, 1))
        w = mix_weight.repeat(bs, 1)            # (bs, k)
        logger.info(f"{w=}")
        logger.info(f"{w.shape=}")
        w = w.unsqueeze(-1).unsqueeze(-1)       # (bs, k, 1, 1)
        y_hat = (w * prompt_logits).sum(dim=1)  # (bs, seq_len, pref_dim)
        self.tmp_store_user_ideal_points = y_hat
        if rm_cached is None:
            return y_hat
        return y_hat, rm_cached

    def eval(self):
        super().eval()
        if self.is_partition:
            warnings.warn("🤖 UserPromptLearner(Partition version) is in eval mode: argmax")
            self.is_argmax = True
        else:
            warnings.warn("🤖 UserPromptLearner(Mixture version) is in eval mode: sfx")
            self.is_argmax = False
        return self

    def train(self, mode: bool = True):
        super().train(mode)
        if mode:
            if self.is_partition:
                warnings.warn("🤖 UserPromptLearner(Partition version) is in train mode: sfx")
                self.is_argmax = False
            else:
                warnings.warn("🤖 UserPromptLearner(Mixture version) is in train mode: sfx")
                self.is_argmax = False
        else:
            if self.is_partition:
                warnings.warn("🤖 UserPromptLearner(Partition version) is in eval mode: argmax")
                self.is_argmax = True
            else:
                warnings.warn("🤖 UserPromptLearner(Mixture version) is in eval mode: sfx")
                self.is_argmax = False
        return self
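
A short sketch of the per-user weight path (init_weight / get_sfx_w), which is independent of the forward pass above. It assumes the modules are importable from a hypothetical package "pal_b_rm" and uses the default OPT-350m backbone (hidden size 512):

import torch
from transformers import AutoModel
from pal_b_rm.connector import Connector
from pal_b_rm.custom_sfx import CustomSoftMax
from pal_b_rm.userLearner import UserLearner

llm = AutoModel.from_pretrained("facebook/opt-350m")
projectors = [Connector(in_dims=512, out_dims=512, cnct_arch="mlp2-gelu-dropout0") for _ in range(2)]
sfx = CustomSoftMax(sfx_type="softmax", temperature=1.0, is_temperature_learnable=False)
user_learner = UserLearner(k=2, llm=llm, projectors=projectors, softmax=sfx)

user_learner.init_weight(["user_a", "user_b"])       # one learnable k-vector per user id
print(user_learner.get_sfx_w(["user_a", "user_b"]))  # (2, 2) mixture weights, rows sum to 1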