daiweichen committed on
Commit
e424d6b
·
verified ·
1 Parent(s): 3a2aa34

Upload PAL_B_RM_opt

Browse files
Files changed (3) hide show
  1. config.json +13 -1
  2. configuration_pal_b_rm.py +2 -0
  3. modeling_pal_b_rm.py +3 -1
config.json CHANGED
@@ -20,5 +20,17 @@
20
  "sfx_temperature": 1.0,
21
  "sfx_type": "softmax",
22
  "torch_dtype": "float32",
23
- "transformers_version": "4.44.2"
 
 
 
 
 
 
 
 
 
 
 
 
24
  }
 
20
  "sfx_temperature": 1.0,
21
  "sfx_type": "softmax",
22
  "torch_dtype": "float32",
23
+ "transformers_version": "4.44.2",
24
+ "uids": [
25
+ "KZL1qeRzHNYSfDAuOctL1iyVV8WC5N",
26
+ "ZzGCcAhvqF0HnKxNsUjtJFadcZdyZj",
27
+ "p4Oh7rUGyLe1EpilJFWr9sPDpkO016",
28
+ "qo6WIyEh27cwAjWpA3Q60J7NaDxzQJ",
29
+ "zKV8BFGy60O0q7102ALF84S6Jo5i4q",
30
+ "i8YiBZlrYmlkkChr5b9BUKvDO6lR1d",
31
+ "M3icahkfAtC9CJrtKgQ7qvyZ5SD8wC",
32
+ "HNzkrs9geGu1YMMfZ5Qvdt0ZaCthfB",
33
+ "Jxv4hxfb9zTVa5nsMDFlnjSX5LZ8MK",
34
+ "UhQipwcpQmiGJmScocXOGOKyCBaFUg"
35
+ ]
36
  }
configuration_pal_b_rm.py CHANGED
@@ -17,6 +17,7 @@ class PAL_B_Config(PretrainedConfig):
17
  sfx_temperature: float = 1.0,
18
  is_temperature_learnable: bool = False,
19
  is_gumbel_hard: bool = None,
 
20
  **kwargs,
21
  ):
22
  self.d_hid = d_hid
@@ -31,5 +32,6 @@ class PAL_B_Config(PretrainedConfig):
31
  self.sfx_temperature = sfx_temperature
32
  self.is_temperature_learnable = is_temperature_learnable
33
  self.is_gumbel_hard = is_gumbel_hard
 
34
  super().__init__(**kwargs)
35
 
 
17
  sfx_temperature: float = 1.0,
18
  is_temperature_learnable: bool = False,
19
  is_gumbel_hard: bool = None,
20
+ uids: list = None,
21
  **kwargs,
22
  ):
23
  self.d_hid = d_hid
 
32
  self.sfx_temperature = sfx_temperature
33
  self.is_temperature_learnable = is_temperature_learnable
34
  self.is_gumbel_hard = is_gumbel_hard
35
+ self.uids = uids
36
  super().__init__(**kwargs)
37
 
modeling_pal_b_rm.py CHANGED
@@ -20,8 +20,10 @@ class PAL_B_RM_opt(PreTrainedModel):
20
  sfx_temperature=config.sfx_temperature,
21
  is_temperature_learnable=config.is_temperature_learnable,
22
  is_gumbel_hard=config.is_gumbel_hard,
 
23
  )
24
- # self.model.user_learner.init_weight(uids)
 
25
 
26
  def forward(self, x):
27
  logits = self.model(x)
 
20
  sfx_temperature=config.sfx_temperature,
21
  is_temperature_learnable=config.is_temperature_learnable,
22
  is_gumbel_hard=config.is_gumbel_hard,
23
+ uids=config.uids,
24
  )
25
+ if config.uids is not None:
26
+ self.model.user_learner.init_weight(config.uids)
27
 
28
  def forward(self, x):
29
  logits = self.model(x)