# pal-b-large-opt-350m / learner.py
# daiweichen's picture
# Upload PAL_B_RM_opt
# 002a82b verified
#!/usr/bin/env python
# -*-coding:utf-8 -*-
'''
@Desc: This is the implementation of PAL-B
'''
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel, AutoConfig
from .connector import Connector
from .tensor_initializer import TensorInitializer
from .custom_sfx import CustomSoftMax
from .itemLearner import ItemLearner
from .userLearner import UserLearner
from collections import defaultdict
from typing import Literal, Optional, Tuple
import logging
logger = logging.getLogger(__name__)
class BasePrefLearner(nn.Module):
    """Shared scaffolding for the PAL-B preference learners.

    The constructor builds the backbone LLM, one item-side projector,
    ``k`` user-side projectors, the softmax module, and the
    ``ItemLearner`` / ``UserLearner`` pair that consume them.  Subclasses
    provide ``forward``.
    """

    def __init__(
        self,
        d_hid: int,
        d_pref: int,
        k: int,
        llm_name: str,
        pref_learner_type: Literal["dist","dist_normalization","angle","norm","dist_logistic","angle_hinge"],
        proj_arch: str,
        initializer_type: Literal["gaussian"],
        is_expectation_norm_init: bool, # the tensor initialization parameters
        sfx_type: Literal["gumbel_softmax", "softmax"],
        sfx_temperature: float,
        is_temperature_learnable: bool,
        is_gumbel_hard: Optional[bool]=None,
        is_partition: bool=False,
        seed: int=42,
        **kwargs
    ):
        """Instantiate every sub-module of the preference learner.

        Args:
            d_hid: Input width given to every ``Connector`` (``in_dims``).
            d_pref: Output width given to every ``Connector`` (``out_dims``).
            k: Number of user-side projectors; also forwarded to
                ``UserLearner``.
            llm_name: Name or path handed to ``AutoConfig.from_pretrained``
                and ``AutoModel.from_pretrained``.  If it contains
                ``".ckpt"`` the weights are loaded with ``from_tf=True``.
            pref_learner_type: Scoring variant.  Stored on the instance; it
                also decides whether ``logit_scale`` is reported as
                trainable by ``update_trainable_params``.
            proj_arch: Architecture string forwarded to ``Connector`` as
                ``cnct_arch``.
            initializer_type: Forwarded to ``TensorInitializer``.
            is_expectation_norm_init: Forwarded to ``TensorInitializer``.
            sfx_type: Forwarded to ``CustomSoftMax``.
            sfx_temperature: Forwarded to ``CustomSoftMax`` as
                ``temperature``.
            is_temperature_learnable: Forwarded to ``CustomSoftMax`` and
                stored on the instance.
            is_gumbel_hard: Forwarded to ``CustomSoftMax``.
            is_partition: Stored on the instance and forwarded to
                ``UserLearner``.
            seed: Forwarded to ``TensorInitializer``.
            **kwargs: Accepted and ignored.
        """
        super().__init__()
        self.pref_learner_type = pref_learner_type
        self.is_temperature_learnable = is_temperature_learnable

        # init all necessary modules
        # NOTE: the construction order below is kept stable on purpose; it
        # determines sub-module registration order and the order in which
        # random initialisation is drawn.
        model_config = AutoConfig.from_pretrained(llm_name)
        self.llm = AutoModel.from_pretrained(
            llm_name,
            from_tf=".ckpt" in llm_name,
            config=model_config,
        )
        self.tensor_initializer = TensorInitializer(
            initializer_type,
            seed,
            is_expectation_norm_init=is_expectation_norm_init,
        )
        self.projector_f = Connector(cnct_arch=proj_arch, in_dims=d_hid, out_dims=d_pref)
        # Plain Python list: these projectors are NOT registered on this
        # module directly.  update_trainable_params() reaches them through
        # self.user_learner.projectors instead.
        # NOTE(review): presumably UserLearner wraps them in a module
        # container -- confirm against userLearner.py.
        self.projectors_gk = [
            Connector(cnct_arch=proj_arch, in_dims=d_hid, out_dims=d_pref)
            for _ in range(k)
        ]
        # Scalar parameter, initial value log(1 / 0.07).
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
        self.softmax_w = CustomSoftMax(
            sfx_type=sfx_type,
            temperature=sfx_temperature,
            is_temperature_learnable=is_temperature_learnable,
            is_gumbel_hard=is_gumbel_hard,
        )
        self.item_learner = ItemLearner(
            llm=self.llm,
            projector=self.projector_f,
        )
        self.is_partition = is_partition
        self.user_learner = UserLearner(
            k=k,
            llm=self.llm,
            projectors=self.projectors_gk,
            softmax=self.softmax_w,
            is_partition=is_partition,
        )
        logger.critical('🛑 Remember to call update_trainable_params() after the model is initialized.')

    def update_trainable_params(self, fix_modules: Tuple[str,...]=()):
        """Rebuild ``self.trainable_params`` from the modules that are not fixed.

        ``self.trainable_params`` is replaced by a fresh ``defaultdict(list)``
        that may contain the keys ``"llm"``, ``"projector_f"``,
        ``"projectors_gk"``, ``"W"``, ``"logit_scale"`` and ``"temperature"``.
        The first four map to lists of parameters; the last two map to the
        ``self.logit_scale`` / ``self.softmax_w.temperature`` objects
        themselves.

        Fixing a module only leaves it out of the mapping (and, for the LLM,
        switches it to eval mode).  ``requires_grad`` is not modified here.

        Args:
            fix_modules: Names of the groups to keep out of the mapping.
                Recognised names: ``"llm"``, ``"itemLearnerProjector"``,
                ``"userLearnerProjector"``, ``"W"``, ``"logit_scale"``,
                ``"temperature"``.
        """
        # capture params
        self.trainable_params = defaultdict(list)

        if "llm" not in fix_modules:
            # Materialise the iterator: Module.parameters() can be consumed
            # only once, which would make this entry silently empty on any
            # second pass (e.g. count the params, then build the optimizer).
            self.trainable_params["llm"] = list(self.llm.parameters())
        else:
            self.llm.eval()

        if "itemLearnerProjector" not in fix_modules:
            self.trainable_params["projector_f"].extend(self.item_learner.projector.parameters())

        if "userLearnerProjector" not in fix_modules:
            self.trainable_params["projectors_gk"].extend(self.user_learner.projectors.parameters())

        if "W" not in fix_modules:
            # Same one-shot-iterator concern as for "llm".
            self.trainable_params["W"] = list(self.user_learner.W.parameters())

        # logit_scale is reported only for the two variants listed here.
        if self.pref_learner_type in ["angle","dist_logistic"] and "logit_scale" not in fix_modules:
            self.trainable_params["logit_scale"] = self.logit_scale

        if self.is_temperature_learnable and "temperature" not in fix_modules:
            self.trainable_params["temperature"] = self.softmax_w.temperature

    def map_to_pref_embedding_space(self, x, rm_cached=None):
        """Run the item learner and the user learner on one input triple.

        Args:
            x: A triple ``(uid, prompt, items)``.  According to the original
                in-code note, ``prompt`` and ``items`` are each a dict with
                ``'input_ids'`` and ``'attention_mask'`` entries::

                    (
                        uid,
                        {'input_ids': prompt_input_ids,
                         'attention_mask': prompt_attention_mask},
                        {'input_ids': eval_input_ids,
                         'attention_mask': eval_attention_mask},
                    )

            rm_cached: Optional cache object.  When given, it is passed
                through the item learner and then the user learner, each of
                which returns an updated cache.

        Returns:
            ``(items_prime, prompt_prime)`` when ``rm_cached`` is ``None``,
            otherwise ``(items_prime, prompt_prime, rm_cached)``.
        """
        uid, prompt, items = x

        if rm_cached is None:
            items_prime = self.item_learner(items)
            prompt_prime = self.user_learner(uid, prompt)
            return items_prime, prompt_prime

        # Cached path: the item learner runs first and its returned cache is
        # the one handed to the user learner.
        items_prime, rm_cached = self.item_learner(items, rm_cached)
        prompt_prime, rm_cached = self.user_learner(uid, prompt, rm_cached)
        return items_prime, prompt_prime, rm_cached
class PrefLearner(BasePrefLearner): # <f(x),f(u)>
    """Preference learner that scores items for one selected user.

    Call ``specify_user_ids`` before ``forward``.  Only the ``'angle'``
    ``pref_learner_type`` is implemented: the score is the cosine similarity
    between the last-position prompt embedding and the last-position item
    embedding, multiplied by ``exp(logit_scale)`` clamped to at most 100.
    """

    def __init__(self, *args, **kwargs):
        """Forward all arguments to ``BasePrefLearner`` and clear the user id."""
        super().__init__(*args, **kwargs)
        # Start with no user selected so that forward() can report a clear
        # error instead of failing on a missing attribute.
        self.uid = None

    def specify_user_ids(self, uid): # personalize the model for a specific user
        """Select the user id(s) used by subsequent ``forward`` calls."""
        self.uid = uid

    def forward(self, x, rm_cached=None):
        """Score ``items`` against ``prompt`` for the currently selected user.

        Args:
            x: A pair ``(prompt, items)``; both are passed, together with
                ``self.uid``, to ``map_to_pref_embedding_space``.
            rm_cached: Optional cache object threaded through
                ``map_to_pref_embedding_space``.

        Returns:
            ``sim_score`` when ``rm_cached`` is ``None``, otherwise
            ``(sim_score, rm_cached)``.  ``sim_score`` has shape ``(bs, 1)``.

        Raises:
            AssertionError: If no user id has been specified.
            NotImplementedError: If ``pref_learner_type`` is not ``'angle'``.
        """
        # Raised explicitly (not via `assert`) so the guard survives `python -O`.
        # getattr keeps the message intact even for instances that never ran
        # this class's __init__.
        if getattr(self, "uid", None) is None:
            raise AssertionError(
                "Please specify the user id first by calling specify_user_ids() to personalize the reward model"
            )

        prompt, items = x

        # Reject unsupported variants before paying for the LLM forward pass.
        if self.pref_learner_type != 'angle':
            raise NotImplementedError(
                f"pref_learner_type={self.pref_learner_type!r} is not implemented in "
                "PrefLearner.forward; only 'angle' is supported"
            )

        if rm_cached is None:
            items_prime, prompt_prime = self.map_to_pref_embedding_space((self.uid, prompt, items))
        else:
            items_prime, prompt_prime, rm_cached = self.map_to_pref_embedding_space(
                (self.uid, prompt, items), rm_cached
            )

        # NOTICE: here we implement the "last token only" version of PAL-B
        # NOTE(review): index -1 is the last sequence position; this assumes
        # that position holds a real token rather than padding -- confirm
        # against the tokenizer's padding side.
        # F.normalize divides by max(||v||, eps), so an all-zero embedding
        # yields zeros instead of NaN; non-zero vectors are unchanged.
        prompt_last_prime = F.normalize(prompt_prime[:, -1, :].unsqueeze(1), dim=-1)  # (bs, 1, d)
        items_last_prime = F.normalize(items_prime[:, -1, :].unsqueeze(1), dim=-1)    # (bs, 1, d)

        # The learnable scale lives in log space; cap the multiplier at 100.
        clamped_logit_scale = torch.clamp(self.logit_scale.exp(), max=100)

        sim_score = (prompt_last_prime * items_last_prime).sum(dim=-1) * clamped_logit_scale  # (bs, 1)

        if rm_cached is None:
            return sim_score
        return sim_score, rm_cached