Upload PAL_B_RM_opt
Browse files- config.json +13 -1
- configuration_pal_b_rm.py +2 -0
- modeling_pal_b_rm.py +3 -1
config.json
CHANGED
@@ -20,5 +20,17 @@
|
|
20 |
"sfx_temperature": 1.0,
|
21 |
"sfx_type": "softmax",
|
22 |
"torch_dtype": "float32",
|
23 |
-
"transformers_version": "4.44.2"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
24 |
}
|
|
|
20 |
"sfx_temperature": 1.0,
|
21 |
"sfx_type": "softmax",
|
22 |
"torch_dtype": "float32",
|
23 |
+
"transformers_version": "4.44.2",
|
24 |
+
"uids": [
|
25 |
+
"KZL1qeRzHNYSfDAuOctL1iyVV8WC5N",
|
26 |
+
"ZzGCcAhvqF0HnKxNsUjtJFadcZdyZj",
|
27 |
+
"p4Oh7rUGyLe1EpilJFWr9sPDpkO016",
|
28 |
+
"qo6WIyEh27cwAjWpA3Q60J7NaDxzQJ",
|
29 |
+
"zKV8BFGy60O0q7102ALF84S6Jo5i4q",
|
30 |
+
"i8YiBZlrYmlkkChr5b9BUKvDO6lR1d",
|
31 |
+
"M3icahkfAtC9CJrtKgQ7qvyZ5SD8wC",
|
32 |
+
"HNzkrs9geGu1YMMfZ5Qvdt0ZaCthfB",
|
33 |
+
"Jxv4hxfb9zTVa5nsMDFlnjSX5LZ8MK",
|
34 |
+
"UhQipwcpQmiGJmScocXOGOKyCBaFUg"
|
35 |
+
]
|
36 |
}
|
configuration_pal_b_rm.py
CHANGED
@@ -17,6 +17,7 @@ class PAL_B_Config(PretrainedConfig):
|
|
17 |
sfx_temperature: float = 1.0,
|
18 |
is_temperature_learnable: bool = False,
|
19 |
is_gumbel_hard: bool = None,
|
|
|
20 |
**kwargs,
|
21 |
):
|
22 |
self.d_hid = d_hid
|
@@ -31,5 +32,6 @@ class PAL_B_Config(PretrainedConfig):
|
|
31 |
self.sfx_temperature = sfx_temperature
|
32 |
self.is_temperature_learnable = is_temperature_learnable
|
33 |
self.is_gumbel_hard = is_gumbel_hard
|
|
|
34 |
super().__init__(**kwargs)
|
35 |
|
|
|
17 |
sfx_temperature: float = 1.0,
|
18 |
is_temperature_learnable: bool = False,
|
19 |
is_gumbel_hard: bool = None,
|
20 |
+
uids: list = None,
|
21 |
**kwargs,
|
22 |
):
|
23 |
self.d_hid = d_hid
|
|
|
32 |
self.sfx_temperature = sfx_temperature
|
33 |
self.is_temperature_learnable = is_temperature_learnable
|
34 |
self.is_gumbel_hard = is_gumbel_hard
|
35 |
+
self.uids = uids
|
36 |
super().__init__(**kwargs)
|
37 |
|
modeling_pal_b_rm.py
CHANGED
@@ -20,8 +20,10 @@ class PAL_B_RM_opt(PreTrainedModel):
|
|
20 |
sfx_temperature=config.sfx_temperature,
|
21 |
is_temperature_learnable=config.is_temperature_learnable,
|
22 |
is_gumbel_hard=config.is_gumbel_hard,
|
|
|
23 |
)
|
24 |
-
|
|
|
25 |
|
26 |
def forward(self, x):
|
27 |
logits = self.model(x)
|
|
|
20 |
sfx_temperature=config.sfx_temperature,
|
21 |
is_temperature_learnable=config.is_temperature_learnable,
|
22 |
is_gumbel_hard=config.is_gumbel_hard,
|
23 |
+
uids=config.uids,
|
24 |
)
|
25 |
+
if config.uids is not None:
|
26 |
+
self.model.user_learner.init_weight(config.uids)
|
27 |
|
28 |
def forward(self, x):
|
29 |
logits = self.model(x)
|