File size: 1,088 Bytes
3a2aa34
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e424d6b
3a2aa34
e424d6b
 
3a2aa34
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
from transformers import PreTrainedModel
from .learner import PrefLearner
from .configuration_pal_b_rm import PAL_B_Config

class PAL_B_RM_opt(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around a ``PrefLearner``.

    Every keyword argument of the wrapped ``PrefLearner`` is taken from the
    ``PAL_B_Config`` attribute of the same name. The learner is stored as
    ``self.model``, and ``forward`` delegates to it.
    """

    # Tells PreTrainedModel which configuration class pairs with this model.
    config_class = PAL_B_Config

    def __init__(self, config):
        """Build the wrapped ``PrefLearner`` from ``config``.

        Args:
            config: A ``PAL_B_Config``. Each attribute listed below is passed
                unchanged to ``PrefLearner`` as the keyword argument of the
                same name. When ``config.uids`` is not ``None``, it is also
                handed to ``self.model.user_learner.init_weight``.
        """
        super().__init__(config)

        # Config attributes forwarded verbatim as PrefLearner keyword args.
        learner_field_names = (
            "d_hid",
            "d_pref",
            "k",
            "llm_name",
            "pref_learner_type",
            "proj_arch",
            "initializer_type",
            "is_expectation_norm_init",
            "sfx_type",
            "sfx_temperature",
            "is_temperature_learnable",
            "is_gumbel_hard",
            "uids",
        )
        learner_kwargs = {
            field_name: getattr(config, field_name)
            for field_name in learner_field_names
        }
        self.model = PrefLearner(**learner_kwargs)

        # Nothing further to set up when no user ids are configured.
        if config.uids is None:
            return
        # NOTE(review): presumably seeds per-user weights for the configured
        # ids; confirm against PrefLearner.user_learner.init_weight.
        self.model.user_learner.init_weight(config.uids)

    def forward(self, x):
        """Run the wrapped learner on ``x``.

        Args:
            x: Input passed straight to ``self.model``.

        Returns:
            dict: A single-entry dict whose ``'logits'`` value is whatever
            ``self.model(x)`` returns, unchanged.
        """
        # NOTE(review): the expected structure of ``x`` is defined by
        # PrefLearner.forward; TODO confirm and document it here.
        return dict(logits=self.model(x))