Upload PAL_B_RM_opt
Browse files
- README.md +4 -4
- itemLearner.py +4 -1
- learner.py +9 -7
- pytorch_model.bin +1 -1
- userLearner.py +7 -2
README.md
CHANGED
@@ -1,12 +1,12 @@
|
|
1 |
---
|
2 |
-
|
3 |
-
|
4 |
datasets:
|
5 |
- CarperAI/openai_summarize_tldr
|
6 |
language:
|
7 |
- en
|
8 |
-
|
9 |
-
|
10 |
---
|
11 |
|
12 |
# Model Card for Model ID
|
|
|
1 |
---
|
2 |
+
base_model:
|
3 |
+
- facebook/opt-350m
|
4 |
datasets:
|
5 |
- CarperAI/openai_summarize_tldr
|
6 |
language:
|
7 |
- en
|
8 |
+
library_name: transformers
|
9 |
+
license: mit
|
10 |
---
|
11 |
|
12 |
# Model Card for Model ID
|
itemLearner.py
CHANGED
@@ -27,6 +27,8 @@ class ItemLearner(nn.Module):
|
|
27 |
'''
|
28 |
input_ids = x['input_ids']
|
29 |
attention_mask = x['attention_mask']
|
|
|
|
|
30 |
|
31 |
if rm_cached is None:
|
32 |
llm_res = self.llm(
|
@@ -37,11 +39,12 @@ class ItemLearner(nn.Module):
|
|
37 |
llm_res = self.llm(
|
38 |
input_ids=input_ids[:, -1:], # attention_mask=attention_mask,
|
39 |
past_key_values=rm_cached["item_learner"],
|
40 |
-
use_cache=
|
41 |
)
|
42 |
rm_cached["item_learner"] = llm_res.past_key_values
|
43 |
|
44 |
embeds = llm_res.last_hidden_state
|
|
|
45 |
# embeds shape: (bs, seq_len, hidden_size)
|
46 |
shape = embeds.shape
|
47 |
embeds = embeds.view(-1, shape[-1]) # (bs*seq_len, hidden_size)
|
|
|
27 |
'''
|
28 |
input_ids = x['input_ids']
|
29 |
attention_mask = x['attention_mask']
|
30 |
+
# logger.critical(f"ItemLearner: {input_ids=}")
|
31 |
+
# logger.critical(f"ItemLearner: {attention_mask=}")
|
32 |
|
33 |
if rm_cached is None:
|
34 |
llm_res = self.llm(
|
|
|
39 |
llm_res = self.llm(
|
40 |
input_ids=input_ids[:, -1:], # attention_mask=attention_mask,
|
41 |
past_key_values=rm_cached["item_learner"],
|
42 |
+
use_cache=True
|
43 |
)
|
44 |
rm_cached["item_learner"] = llm_res.past_key_values
|
45 |
|
46 |
embeds = llm_res.last_hidden_state
|
47 |
+
# logger.critical(f"ItemLearner: {embeds=}")
|
48 |
# embeds shape: (bs, seq_len, hidden_size)
|
49 |
shape = embeds.shape
|
50 |
embeds = embeds.view(-1, shape[-1]) # (bs*seq_len, hidden_size)
|
learner.py
CHANGED
@@ -113,16 +113,18 @@ class PrefLearner(BasePrefLearner): # <f(x),f(u)>
|
|
113 |
|
114 |
def forward(self, x, rm_cached=None):
|
115 |
assert self.uid is not None, "Please specify the user id first by calling specify_user_ids() to personalize the reward model"
|
116 |
-
|
117 |
if rm_cached is None:
|
118 |
items_prime, prompt_prime = self.map_to_pref_embedding_space((self.uid, prompt, items))
|
119 |
else:
|
120 |
items_prime, prompt_prime, rm_cached = self.map_to_pref_embedding_space((self.uid, prompt, items), rm_cached)
|
121 |
-
logger.
|
122 |
-
logger.
|
123 |
-
logger.
|
124 |
-
logger.
|
|
|
125 |
if self.pref_learner_type == 'angle':
|
|
|
126 |
prompt_last_prime = prompt_prime[:, -1, :]
|
127 |
prompt_last_prime = prompt_last_prime.unsqueeze(1)
|
128 |
prompt_last_prime = prompt_last_prime / torch.norm(prompt_last_prime, dim=-1, keepdim=True)
|
@@ -131,8 +133,8 @@ class PrefLearner(BasePrefLearner): # <f(x),f(u)>
|
|
131 |
items_last_prime = items_last_prime / torch.norm(items_last_prime, dim=-1, keepdim=True)
|
132 |
logit_scale = self.logit_scale.exp()
|
133 |
clamped_logit_scale = torch.clamp(logit_scale, max=100)
|
134 |
-
logger.
|
135 |
-
logger.
|
136 |
sim_score = (prompt_last_prime * items_last_prime).sum(dim=-1) * clamped_logit_scale # (bs, max_token_length)
|
137 |
if rm_cached is None:
|
138 |
return sim_score
|
|
|
113 |
|
114 |
def forward(self, x, rm_cached=None):
|
115 |
assert self.uid is not None, "Please specify the user id first by calling specify_user_ids() to personalize the reward model"
|
116 |
+
prompt, items = x
|
117 |
if rm_cached is None:
|
118 |
items_prime, prompt_prime = self.map_to_pref_embedding_space((self.uid, prompt, items))
|
119 |
else:
|
120 |
items_prime, prompt_prime, rm_cached = self.map_to_pref_embedding_space((self.uid, prompt, items), rm_cached)
|
121 |
+
# logger.critical(f"{items_prime[0]=}")
|
122 |
+
# logger.critical(f"{prompt_prime[0]=}")
|
123 |
+
# logger.critical(f"{items_prime.shape=}")
|
124 |
+
# logger.critical(f"{prompt_prime.shape=}")
|
125 |
+
# FIXME: bug exist here
|
126 |
if self.pref_learner_type == 'angle':
|
127 |
+
# FIXME: do the cumulative evaluation!
|
128 |
prompt_last_prime = prompt_prime[:, -1, :]
|
129 |
prompt_last_prime = prompt_last_prime.unsqueeze(1)
|
130 |
prompt_last_prime = prompt_last_prime / torch.norm(prompt_last_prime, dim=-1, keepdim=True)
|
|
|
133 |
items_last_prime = items_last_prime / torch.norm(items_last_prime, dim=-1, keepdim=True)
|
134 |
logit_scale = self.logit_scale.exp()
|
135 |
clamped_logit_scale = torch.clamp(logit_scale, max=100)
|
136 |
+
# logger.critical(f"{prompt_last_prime.shape=}")
|
137 |
+
# logger.critical(f"{items_last_prime.shape=}")
|
138 |
sim_score = (prompt_last_prime * items_last_prime).sum(dim=-1) * clamped_logit_scale # (bs, max_token_length)
|
139 |
if rm_cached is None:
|
140 |
return sim_score
|
pytorch_model.bin
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
size 1334487698
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:8c5e8e5083c6c333b9ba3284e989dab95ef00f70e1b97770df255acebada4388
|
3 |
size 1334487698
|
userLearner.py
CHANGED
@@ -92,6 +92,11 @@ class UserLearner(nn.Module):
|
|
92 |
|
93 |
# embeds shape: (bs, seq_len, hid_dim)
|
94 |
shape = embeds.shape
|
|
|
|
|
|
|
|
|
|
|
95 |
embeds = embeds.view(-1, shape[-1]) # (bs*seq_len, hid_dim)
|
96 |
# g(embeds) shape: (bs*seq_len, hid_dim) -> (bs*seq_len, pref_dim)
|
97 |
logits = torch.stack([g(embeds).view(shape[0], shape[1], -1) for g in self.projectors.values()],dim=1)
|
@@ -118,8 +123,8 @@ class UserLearner(nn.Module):
|
|
118 |
# assert sum(mix_weight) == 1
|
119 |
# w = self.softmax(mix_weight.repeat(bs, 1))
|
120 |
# w = mix_weight.repeat(bs, 1)
|
121 |
-
logger.info(f"{w=}")
|
122 |
-
logger.info(f"{w.shape=}")
|
123 |
w = w.unsqueeze(-1).unsqueeze(-1)
|
124 |
y_hat = (w * prompt_logits).sum(dim=1)
|
125 |
self.tmp_store_user_ideal_points = y_hat
|
|
|
92 |
|
93 |
# embeds shape: (bs, seq_len, hid_dim)
|
94 |
shape = embeds.shape
|
95 |
+
# only last hidden state start
|
96 |
+
embeds = embeds[:, -1, :] # (bs, seq_len, hid_dim) -> (bs, hid_dim)
|
97 |
+
embeds = embeds.unsqueeze(1).repeat(1, shape[1], 1) # (bs, hid_dim) -> (bs, seq_len, hid_dim)
|
98 |
+
# only last hidden state end
|
99 |
+
# logger.critical("using only last hidden state of prompt tokens")
|
100 |
embeds = embeds.view(-1, shape[-1]) # (bs*seq_len, hid_dim)
|
101 |
# g(embeds) shape: (bs*seq_len, hid_dim) -> (bs*seq_len, pref_dim)
|
102 |
logits = torch.stack([g(embeds).view(shape[0], shape[1], -1) for g in self.projectors.values()],dim=1)
|
|
|
123 |
# assert sum(mix_weight) == 1
|
124 |
# w = self.softmax(mix_weight.repeat(bs, 1))
|
125 |
# w = mix_weight.repeat(bs, 1)
|
126 |
+
# logger.info(f"{w=}")
|
127 |
+
# logger.info(f"{w.shape=}")
|
128 |
w = w.unsqueeze(-1).unsqueeze(-1)
|
129 |
y_hat = (w * prompt_logits).sum(dim=1)
|
130 |
self.tmp_store_user_ideal_points = y_hat
|