Upload PAL_B_RM_opt
Browse files
- learner.py +1 -2
- userLearner.py +1 -1
learner.py
CHANGED
@@ -122,9 +122,8 @@ class PrefLearner(BasePrefLearner): # <f(x),f(u)>
|
|
122 |
# logger.critical(f"{prompt_prime[0]=}")
|
123 |
# logger.critical(f"{items_prime.shape=}")
|
124 |
# logger.critical(f"{prompt_prime.shape=}")
|
125 |
- # FIXME: bug exist here
|
126 |
if self.pref_learner_type == 'angle':
|
127 |
- #
|
128 |
prompt_last_prime = prompt_prime[:, -1, :]
|
129 |
prompt_last_prime = prompt_last_prime.unsqueeze(1)
|
130 |
prompt_last_prime = prompt_last_prime / torch.norm(prompt_last_prime, dim=-1, keepdim=True)
|
|
|
122 |
# logger.critical(f"{prompt_prime[0]=}")
|
123 |
# logger.critical(f"{items_prime.shape=}")
|
124 |
# logger.critical(f"{prompt_prime.shape=}")
|
|
|
125 |
if self.pref_learner_type == 'angle':
|
126 |
+ # NOTICE: here we implement the "last token only" version of PAL-B
|
127 |
prompt_last_prime = prompt_prime[:, -1, :]
|
128 |
prompt_last_prime = prompt_last_prime.unsqueeze(1)
|
129 |
prompt_last_prime = prompt_last_prime / torch.norm(prompt_last_prime, dim=-1, keepdim=True)
|
userLearner.py
CHANGED
@@ -92,7 +92,7 @@ class UserLearner(nn.Module):
|
|
92 |
|
93 |
# embeds shape: (bs, seq_len, hid_dim)
|
94 |
shape = embeds.shape
|
95 |
- # only last hidden state start
|
96 |
embeds = embeds[:, -1, :] # (bs, seq_len, hid_dim) -> (bs, hid_dim)
|
97 |
embeds = embeds.unsqueeze(1).repeat(1, shape[1], 1) # (bs, hid_dim) -> (bs, seq_len, hid_dim)
|
98 |
# only last hidden state end
|
|
|
92 |
|
93 |
# embeds shape: (bs, seq_len, hid_dim)
|
94 |
shape = embeds.shape
|
95 |
+ # only last hidden state start (only use the last token of the prompt)
|
96 |
embeds = embeds[:, -1, :] # (bs, seq_len, hid_dim) -> (bs, hid_dim)
|
97 |
embeds = embeds.unsqueeze(1).repeat(1, shape[1], 1) # (bs, hid_dim) -> (bs, seq_len, hid_dim)
|
98 |
# only last hidden state end
|