daiweichen committed (verified)
Commit 002a82b · Parent(s): 51fef75

Upload PAL_B_RM_opt

Files changed (2)
  1. learner.py +1 -2
  2. userLearner.py +1 -1
learner.py CHANGED
@@ -122,9 +122,8 @@ class PrefLearner(BasePrefLearner): # <f(x),f(u)>
         # logger.critical(f"{prompt_prime[0]=}")
         # logger.critical(f"{items_prime.shape=}")
         # logger.critical(f"{prompt_prime.shape=}")
-        # FIXME: bug exist here
         if self.pref_learner_type == 'angle':
-            # FIXME: do the cumulative evaluation!
+            # NOTICE: here we implement the "last token only" version of PAL-B
             prompt_last_prime = prompt_prime[:, -1, :]
             prompt_last_prime = prompt_last_prime.unsqueeze(1)
             prompt_last_prime = prompt_last_prime / torch.norm(prompt_last_prime, dim=-1, keepdim=True)
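For context, a minimal standalone sketch of what the 'angle' branch now computes: it keeps only the last token of the prompt embedding and L2-normalizes it, so downstream scores depend only on the direction of that vector. The function name is hypothetical, and the (bs, seq_len, hid_dim) input shape is assumed from the comments in userLearner.py:

    import torch

    def last_token_direction(prompt_prime: torch.Tensor) -> torch.Tensor:
        # prompt_prime: (bs, seq_len, hid_dim) -- shape assumed from the
        # comments in userLearner.py; this helper is illustrative only.
        last = prompt_prime[:, -1, :]   # (bs, hid_dim): last prompt token
        last = last.unsqueeze(1)        # (bs, 1, hid_dim)
        # unit-normalize so only the angle (direction) of the embedding matters
        return last / torch.norm(last, dim=-1, keepdim=True)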
userLearner.py CHANGED
@@ -92,7 +92,7 @@ class UserLearner(nn.Module):
 
         # embeds shape: (bs, seq_len, hid_dim)
         shape = embeds.shape
-        # only last hidden state start
+        # only last hidden state start (only use the last token of the prompt)
         embeds = embeds[:, -1, :] # (bs, seq_len, hid_dim) -> (bs, hid_dim)
         embeds = embeds.unsqueeze(1).repeat(1, shape[1], 1) # (bs, hid_dim) -> (bs, seq_len, hid_dim)
         # only last hidden state end
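The same "last token only" idea appears here from the UserLearner side: the sequence is collapsed to its last hidden state, which is then repeated across positions so callers that expect a (bs, seq_len, hid_dim) tensor keep working. A minimal sketch, with a hypothetical function name:

    import torch

    def broadcast_last_hidden(embeds: torch.Tensor) -> torch.Tensor:
        # embeds: (bs, seq_len, hid_dim), as in the comment in userLearner.py
        bs, seq_len, hid_dim = embeds.shape
        last = embeds[:, -1, :]                         # (bs, hid_dim)
        # repeat the last token's embedding at every position so the
        # output shape matches the input: (bs, seq_len, hid_dim)
        return last.unsqueeze(1).repeat(1, seq_len, 1)

As a design note, torch.Tensor.expand would avoid materializing the seq_len copies, but repeat (as used in the repo) yields an ordinary contiguous tensor that is safe to write to in place.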