import torch
import torch.nn as nn
import torch.nn.functional as F
from .connector import Connector
from .projector import Projector
from .tensor_merger import TensorMerger
import numpy as np
from typing import Literal, Optional, Tuple
import logging

logging.basicConfig(
    level=logging.WARNING,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
)
logger = logging.getLogger(__name__)


class ItemLearner(nn.Module):
    """Encodes item token sequences with an LLM backbone and projects the
    resulting hidden states into the recommender embedding space."""

    llm: nn.Module
    projector: nn.Module

    def __init__(self, llm, projector):
        super().__init__()
        self.llm = llm
        self.projector = projector

    def forward(self, x, rm_cached=None):
        """Run the backbone and project its hidden states.

        x = {'input_ids': torch.Tensor, 'attention_mask': torch.Tensor}

        When `rm_cached` is provided, only the last generated token is passed
        to the backbone, and the key/value cache stored under
        `rm_cached["item_learner"]` is reused and refreshed.
        """
        input_ids = x['input_ids']
        attention_mask = x['attention_mask']
        # logger.critical(f"ItemLearner: {input_ids=}")
        # logger.critical(f"ItemLearner: {attention_mask=}")

        if rm_cached is None:
            # Full forward pass over the whole sequence.
            llm_res = self.llm(
                input_ids=input_ids,
                attention_mask=attention_mask,
            )
        else:
            # Incremental decoding: feed only the newest token and reuse the cache.
            llm_res = self.llm(
                input_ids=input_ids[:, -1:],
                # attention_mask=attention_mask,
                past_key_values=rm_cached["item_learner"],
                use_cache=True,
            )
            rm_cached["item_learner"] = llm_res.past_key_values

        embeds = llm_res.last_hidden_state  # (bs, seq_len, hidden_size)
        # logger.critical(f"ItemLearner: {embeds=}")
        shape = embeds.shape
        embeds = embeds.view(-1, shape[-1])  # (bs * seq_len, hidden_size)
        projected_embeds = self.projector(embeds)

        if rm_cached is None:
            return projected_embeds.view(shape[0], shape[1], -1)
        else:
            return projected_embeds.view(shape[0], shape[1], -1), rm_cached
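

# ---------------------------------------------------------------------------
# Usage sketch (an illustrative assumption, not part of the original module):
# it assumes a Hugging Face-style decoder whose forward() returns an object
# exposing `last_hidden_state` and `past_key_values`, and a plain nn.Linear
# standing in for the repo's Projector. The model name, 64-d embedding size,
# and example sentence are placeholders. Because of the relative imports at
# the top of this file, run it as `python -m <package>.<module>`.
if __name__ == "__main__":
    from transformers import AutoModel, AutoTokenizer

    backbone = AutoModel.from_pretrained("gpt2")   # any decoder LM exposing hidden states
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token

    # Hypothetical projection from the LLM hidden size to a 64-d item embedding.
    projector = nn.Linear(backbone.config.hidden_size, 64)
    item_learner = ItemLearner(backbone, projector)

    batch = tokenizer(["a movie about space travel"], return_tensors="pt", padding=True)
    x = {"input_ids": batch["input_ids"], "attention_mask": batch["attention_mask"]}

    # Plain path: full sequence, no KV cache -> (bs, seq_len, 64).
    item_embeds = item_learner(x)
    print(item_embeds.shape)

    # Cached path, as used during step-by-step generation: only the last token
    # is fed to the backbone and the cache is refreshed inside rm_cached.
    rm_cached = {"item_learner": None}
    step_embeds, rm_cached = item_learner(x, rm_cached=rm_cached)
    print(step_embeds.shape)  # (bs, 1, 64)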