|
|
|
|
|
|
|
''' |
|
@Desc: Implementation of PAL-B, a personalizable preference (reward) learner built on a shared LLM backbone
|
''' |
|
|
|
import numpy as np |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from transformers import AutoModel, AutoConfig |
|
|
|
from .connector import Connector |
|
from .tensor_initializer import TensorInitializer |
|
from .custom_sfx import CustomSoftMax |
|
from .itemLearner import ItemLearner |
|
from .userLearner import UserLearner |
|
|
|
from collections import defaultdict |
|
from typing import Literal, Optional, Tuple |
|
|
|
import logging |
|
logger = logging.getLogger(__name__) |
|
|
|
class BasePrefLearner(nn.Module): |
|
def __init__( |
|
self, |
|
d_hid: int, |
|
d_pref: int, |
|
k: int, |
|
llm_name: str, |
|
pref_learner_type: Literal["dist","dist_normalization","angle","norm","dist_logistic","angle_hinge"], |
|
proj_arch: str, |
|
initializer_type: Literal["gaussian"], |
|
is_expectation_norm_init: bool, |
|
sfx_type: Literal["gumbel_softmax", "softmax"], |
|
sfx_temperature: float, |
|
is_temperature_learnable: bool, |
|
is_gumbel_hard: Optional[bool]=None, |
|
is_partition: bool=False, |
|
seed: int=42, |
|
**kwargs |
|
): |
|
super().__init__() |
|
self.pref_learner_type = pref_learner_type |
|
self.is_temperature_learnable = is_temperature_learnable |
|
|
|
model_config = AutoConfig.from_pretrained(llm_name) |
|
        self.llm = AutoModel.from_pretrained(llm_name, from_tf=bool(".ckpt" in llm_name), config=model_config)
|
self.tensor_initializer = TensorInitializer(initializer_type, seed, is_expectation_norm_init=is_expectation_norm_init) |
|
        self.projector_f = Connector(cnct_arch=proj_arch, in_dims=d_hid, out_dims=d_pref)
|
        # Register the k user-side projectors as submodules so their parameters move/save with the model.
        self.projectors_gk = nn.ModuleList([Connector(cnct_arch=proj_arch, in_dims=d_hid, out_dims=d_pref) for _ in range(k)])
|
        # Learnable logit scale, initialized CLIP-style to log(1 / 0.07).
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
|
self.softmax_w = CustomSoftMax(sfx_type=sfx_type, |
|
temperature=sfx_temperature, |
|
is_temperature_learnable=is_temperature_learnable, |
|
is_gumbel_hard=is_gumbel_hard) |
|
self.item_learner = ItemLearner( |
|
llm = self.llm, |
|
projector=self.projector_f, |
|
) |
|
self.is_partition = is_partition |
|
self.user_learner = UserLearner(k=k, llm=self.llm, projectors=self.projectors_gk, softmax=self.softmax_w, is_partition=is_partition) |
|
logger.critical('🛑 Remember to call update_trainable_params() after the model is initialized.') |
|
|
|
    def update_trainable_params(self, fix_modules: Tuple[str, ...] = ()):
        """Collect the parameter groups to be optimized, skipping any module names listed in fix_modules."""
|
|
|
self.trainable_params = defaultdict(list) |
|
        if "llm" not in fix_modules:

            self.trainable_params["llm"] = list(self.llm.parameters())

        else:

            # eval() only switches dropout/normalization layers to inference mode; the LLM is
            # excluded from training by leaving its parameters out of trainable_params.
            self.llm.eval()
|
if "itemLearnerProjector" not in fix_modules: |
|
self.trainable_params["projector_f"].extend(self.item_learner.projector.parameters()) |
|
if "userLearnerProjector" not in fix_modules: |
|
self.trainable_params["projectors_gk"].extend(list(self.user_learner.projectors.parameters())) |
|
        if "W" not in fix_modules:

            self.trainable_params["W"] = list(self.user_learner.W.parameters())
|
        if self.pref_learner_type in ["angle", "dist_logistic"] and "logit_scale" not in fix_modules:

            # Wrap the single Parameter in a list so every entry of trainable_params is a parameter list.
            self.trainable_params["logit_scale"] = [self.logit_scale]
|
        if self.is_temperature_learnable and "temperature" not in fix_modules:

            self.trainable_params["temperature"] = [self.softmax_w.temperature]
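
        # Hypothetical downstream use (assumption, not defined in this file): the collected
        # groups can be turned into optimizer parameter groups, e.g.
        #     optimizer = torch.optim.AdamW(
        #         [{"params": params} for params in self.trainable_params.values()], lr=1e-4
        #     )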
|
|
|
    def map_to_pref_embedding_space(self, x, rm_cached=None):
        """Map a (uid, prompt, items) triple into the preference embedding space.

        Args:
            x: tuple of (uid, prompt, items) consumed by the item and user learners.
            rm_cached: optional reward-model cache; when given, it is threaded through both
                learners and returned alongside the embeddings.

        Returns:
            (items_prime, prompt_prime), plus rm_cached when it was provided.
        """

uid, prompt, items = x |
|
if rm_cached is None: |
|
items_prime = self.item_learner(items) |
|
prompt_prime = self.user_learner(uid, prompt) |
|
return items_prime, prompt_prime |
|
else: |
|
items_prime, rm_cached = self.item_learner(items, rm_cached) |
|
prompt_prime, rm_cached = self.user_learner(uid, prompt, rm_cached) |
|
return items_prime, prompt_prime, rm_cached |
|
|
|
class PrefLearner(BasePrefLearner): |
|
|
|
    def __init__(self, *args, **kwargs):

        super().__init__(*args, **kwargs)

        # No user is bound yet; forward() requires specify_user_ids() to be called first.
        self.uid = None
|
|
|
def specify_user_ids(self, uid): |
|
self.uid = uid |
|
|
|
def forward(self, x, rm_cached=None): |
|
assert self.uid is not None, "Please specify the user id first by calling specify_user_ids() to personalize the reward model" |
|
prompt, items = x |
|
if rm_cached is None: |
|
items_prime, prompt_prime = self.map_to_pref_embedding_space((self.uid, prompt, items)) |
|
else: |
|
items_prime, prompt_prime, rm_cached = self.map_to_pref_embedding_space((self.uid, prompt, items), rm_cached) |
|
        # Score each item against the user-conditioned prompt embedding according to pref_learner_type.
|
if self.pref_learner_type == 'angle': |
|
|
|
            # Use the last-token representation of the prompt and L2-normalize it.
            prompt_last_prime = prompt_prime[:, -1, :]

            prompt_last_prime = prompt_last_prime.unsqueeze(1)

            prompt_last_prime = prompt_last_prime / torch.norm(prompt_last_prime, dim=-1, keepdim=True)

            # Likewise, take the last-token item representations and L2-normalize them.
            items_last_prime = items_prime[:, -1, :]

            items_last_prime = items_last_prime.unsqueeze(1)

            items_last_prime = items_last_prime / torch.norm(items_last_prime, dim=-1, keepdim=True)
|
            # Exponentiate the learnable log-scale and clamp it (CLIP-style) to keep logits bounded.
            logit_scale = self.logit_scale.exp()

            clamped_logit_scale = torch.clamp(logit_scale, max=100)
|
|
|
|
|
            # Scaled cosine similarity between the prompt and item embeddings.
            sim_score = (prompt_last_prime * items_last_prime).sum(dim=-1) * clamped_logit_scale
|
if rm_cached is None: |
|
return sim_score |
|
else: |
|
return sim_score, rm_cached |
|
        else:

            raise NotImplementedError(f"pref_learner_type '{self.pref_learner_type}' is not implemented in forward()")
|
|
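
# --- Minimal usage sketch (illustrative only, not part of the original module) ---
# The argument values below (d_hid, d_pref, k, llm_name, proj_arch, and the uid/prompt/item
# formats) are placeholders; the accepted values depend on Connector, UserLearner, and the data.
#
#   model = PrefLearner(
#       d_hid=768, d_pref=64, k=5,
#       llm_name="gpt2",
#       pref_learner_type="angle",
#       proj_arch="mlp",
#       initializer_type="gaussian",
#       is_expectation_norm_init=True,
#       sfx_type="softmax",
#       sfx_temperature=1.0,
#       is_temperature_learnable=False,
#   )
#   model.update_trainable_params(fix_modules=("llm",))
#   model.specify_user_ids(uid)
#   scores = model((prompt_batch, items_batch))   # scaled cosine preference scores for 'angle'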