daiweichen committed on
Commit
51fef75
·
verified ·
1 Parent(s): e0ba201

Upload PAL_B_RM_opt

Browse files
Files changed (5) hide show
  1. README.md +4 -4
  2. itemLearner.py +4 -1
  3. learner.py +9 -7
  4. pytorch_model.bin +1 -1
  5. userLearner.py +7 -2
README.md CHANGED
@@ -1,12 +1,12 @@
1
  ---
2
- library_name: transformers
3
- license: mit
4
  datasets:
5
  - CarperAI/openai_summarize_tldr
6
  language:
7
  - en
8
- base_model:
9
- - facebook/opt-350m
10
  ---
11
 
12
  # Model Card for Model ID
 
1
  ---
2
+ base_model:
3
+ - facebook/opt-350m
4
  datasets:
5
  - CarperAI/openai_summarize_tldr
6
  language:
7
  - en
8
+ library_name: transformers
9
+ license: mit
10
  ---
11
 
12
  # Model Card for Model ID
itemLearner.py CHANGED
@@ -27,6 +27,8 @@ class ItemLearner(nn.Module):
27
  '''
28
  input_ids = x['input_ids']
29
  attention_mask = x['attention_mask']
 
 
30
 
31
  if rm_cached is None:
32
  llm_res = self.llm(
@@ -37,11 +39,12 @@ class ItemLearner(nn.Module):
37
  llm_res = self.llm(
38
  input_ids=input_ids[:, -1:], # attention_mask=attention_mask,
39
  past_key_values=rm_cached["item_learner"],
40
- use_cache=False
41
  )
42
  rm_cached["item_learner"] = llm_res.past_key_values
43
 
44
  embeds = llm_res.last_hidden_state
 
45
  # embeds shape: (bs, seq_len, hidden_size)
46
  shape = embeds.shape
47
  embeds = embeds.view(-1, shape[-1]) # (bs*seq_len, hidden_size)
 
27
  '''
28
  input_ids = x['input_ids']
29
  attention_mask = x['attention_mask']
30
+ # logger.critical(f"ItemLearner: {input_ids=}")
31
+ # logger.critical(f"ItemLearner: {attention_mask=}")
32
 
33
  if rm_cached is None:
34
  llm_res = self.llm(
 
39
  llm_res = self.llm(
40
  input_ids=input_ids[:, -1:], # attention_mask=attention_mask,
41
  past_key_values=rm_cached["item_learner"],
42
+ use_cache=True
43
  )
44
  rm_cached["item_learner"] = llm_res.past_key_values
45
 
46
  embeds = llm_res.last_hidden_state
47
+ # logger.critical(f"ItemLearner: {embeds=}")
48
  # embeds shape: (bs, seq_len, hidden_size)
49
  shape = embeds.shape
50
  embeds = embeds.view(-1, shape[-1]) # (bs*seq_len, hidden_size)
learner.py CHANGED
@@ -113,16 +113,18 @@ class PrefLearner(BasePrefLearner): # <f(x),f(u)>
113
 
114
  def forward(self, x, rm_cached=None):
115
  assert self.uid is not None, "Please specify the user id first by calling specify_user_ids() to personalize the reward model"
116
- items, prompt = x
117
  if rm_cached is None:
118
  items_prime, prompt_prime = self.map_to_pref_embedding_space((self.uid, prompt, items))
119
  else:
120
  items_prime, prompt_prime, rm_cached = self.map_to_pref_embedding_space((self.uid, prompt, items), rm_cached)
121
- logger.info(f"{items_prime[0]=}")
122
- logger.info(f"{prompt_prime[0]=}")
123
- logger.info(f"{items_prime.shape=}")
124
- logger.info(f"{prompt_prime.shape=}")
 
125
  if self.pref_learner_type == 'angle':
 
126
  prompt_last_prime = prompt_prime[:, -1, :]
127
  prompt_last_prime = prompt_last_prime.unsqueeze(1)
128
  prompt_last_prime = prompt_last_prime / torch.norm(prompt_last_prime, dim=-1, keepdim=True)
@@ -131,8 +133,8 @@ class PrefLearner(BasePrefLearner): # <f(x),f(u)>
131
  items_last_prime = items_last_prime / torch.norm(items_last_prime, dim=-1, keepdim=True)
132
  logit_scale = self.logit_scale.exp()
133
  clamped_logit_scale = torch.clamp(logit_scale, max=100)
134
- logger.info(f"{prompt_last_prime.shape=}")
135
- logger.info(f"{items_last_prime.shape=}")
136
  sim_score = (prompt_last_prime * items_last_prime).sum(dim=-1) * clamped_logit_scale # (bs, max_token_length)
137
  if rm_cached is None:
138
  return sim_score
 
113
 
114
  def forward(self, x, rm_cached=None):
115
  assert self.uid is not None, "Please specify the user id first by calling specify_user_ids() to personalize the reward model"
116
+ prompt, items = x
117
  if rm_cached is None:
118
  items_prime, prompt_prime = self.map_to_pref_embedding_space((self.uid, prompt, items))
119
  else:
120
  items_prime, prompt_prime, rm_cached = self.map_to_pref_embedding_space((self.uid, prompt, items), rm_cached)
121
+ # logger.critical(f"{items_prime[0]=}")
122
+ # logger.critical(f"{prompt_prime[0]=}")
123
+ # logger.critical(f"{items_prime.shape=}")
124
+ # logger.critical(f"{prompt_prime.shape=}")
125
+ # FIXME: bug exists here
126
  if self.pref_learner_type == 'angle':
127
+ # FIXME: do the cumulative evaluation!
128
  prompt_last_prime = prompt_prime[:, -1, :]
129
  prompt_last_prime = prompt_last_prime.unsqueeze(1)
130
  prompt_last_prime = prompt_last_prime / torch.norm(prompt_last_prime, dim=-1, keepdim=True)
 
133
  items_last_prime = items_last_prime / torch.norm(items_last_prime, dim=-1, keepdim=True)
134
  logit_scale = self.logit_scale.exp()
135
  clamped_logit_scale = torch.clamp(logit_scale, max=100)
136
+ # logger.critical(f"{prompt_last_prime.shape=}")
137
+ # logger.critical(f"{items_last_prime.shape=}")
138
  sim_score = (prompt_last_prime * items_last_prime).sum(dim=-1) * clamped_logit_scale # (bs, max_token_length)
139
  if rm_cached is None:
140
  return sim_score
pytorch_model.bin CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:770877c170b8b51c6e6555de658213f0a6a1fca5c74370f1b8fed47cf6411bac
3
  size 1334487698
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8c5e8e5083c6c333b9ba3284e989dab95ef00f70e1b97770df255acebada4388
3
  size 1334487698
userLearner.py CHANGED
@@ -92,6 +92,11 @@ class UserLearner(nn.Module):
92
 
93
  # embeds shape: (bs, seq_len, hid_dim)
94
  shape = embeds.shape
 
 
 
 
 
95
  embeds = embeds.view(-1, shape[-1]) # (bs*seq_len, hid_dim)
96
  # g(embeds) shape: (bs*seq_len, hid_dim) -> (bs*seq_len, pref_dim)
97
  logits = torch.stack([g(embeds).view(shape[0], shape[1], -1) for g in self.projectors.values()],dim=1)
@@ -118,8 +123,8 @@ class UserLearner(nn.Module):
118
  # assert sum(mix_weight) == 1
119
  # w = self.softmax(mix_weight.repeat(bs, 1))
120
  # w = mix_weight.repeat(bs, 1)
121
- logger.info(f"{w=}")
122
- logger.info(f"{w.shape=}")
123
  w = w.unsqueeze(-1).unsqueeze(-1)
124
  y_hat = (w * prompt_logits).sum(dim=1)
125
  self.tmp_store_user_ideal_points = y_hat
 
92
 
93
  # embeds shape: (bs, seq_len, hid_dim)
94
  shape = embeds.shape
95
+ # only last hidden state start
96
+ embeds = embeds[:, -1, :] # (bs, seq_len, hid_dim) -> (bs, hid_dim)
97
+ embeds = embeds.unsqueeze(1).repeat(1, shape[1], 1) # (bs, hid_dim) -> (bs, seq_len, hid_dim)
98
+ # only last hidden state end
99
+ # logger.critical("using only last hidden state of prompt tokens")
100
  embeds = embeds.view(-1, shape[-1]) # (bs*seq_len, hid_dim)
101
  # g(embeds) shape: (bs*seq_len, hid_dim) -> (bs*seq_len, pref_dim)
102
  logits = torch.stack([g(embeds).view(shape[0], shape[1], -1) for g in self.projectors.values()],dim=1)
 
123
  # assert sum(mix_weight) == 1
124
  # w = self.softmax(mix_weight.repeat(bs, 1))
125
  # w = mix_weight.repeat(bs, 1)
126
+ # logger.info(f"{w=}")
127
+ # logger.info(f"{w.shape=}")
128
  w = w.unsqueeze(-1).unsqueeze(-1)
129
  y_hat = (w * prompt_logits).sum(dim=1)
130
  self.tmp_store_user_ideal_points = y_hat