from typing import Optional

import torch
import torch.nn as nn

from .config import LlamaConfig
from .decoder import LlamaDecoderLayer
from .rms_norm import LlamaRMSNorm


class LlamaModel(nn.Module):
    """Llama transformer backbone: token embedding, a stack of decoder
    layers, and a final RMSNorm. Returns the last hidden states; there is
    no language-model head in this module."""

    def __init__(self, config: LlamaConfig):
        super().__init__()
        self.config = config
        # Token embedding table; no padding index is configured here.
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=None)
        # One decoder layer per hidden layer, each given its own layer index.
        self.layers = nn.ModuleList(
            [LlamaDecoderLayer(config, i) for i in range(config.num_hidden_layers)]
        )
        # Final normalization applied after the last decoder layer.
        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
    ) -> torch.Tensor:
        # (batch, seq_len) token ids -> (batch, seq_len, hidden_size) embeddings.
        hidden_states = self.embed_tokens(input_ids)
        # Pass the hidden states through each decoder layer in sequence.
        for decoder_layer in self.layers:
            hidden_states = decoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
            )
        # Normalize the final hidden states before returning them.
        hidden_states = self.norm(hidden_states)
        return hidden_states
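

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). It assumes
# LlamaConfig accepts the fields referenced above (vocab_size, hidden_size,
# num_hidden_layers, rms_norm_eps) as keyword arguments; adjust to the real
# constructor. Because of the relative imports, run it as a module, e.g.
# `python -m <package>.model`.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    config = LlamaConfig(          # assumed constructor signature
        vocab_size=1000,
        hidden_size=64,
        num_hidden_layers=2,
        rms_norm_eps=1e-6,
    )
    model = LlamaModel(config)

    batch_size, seq_len = 2, 8
    input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_len))
    # Explicit positions 0..seq_len-1 for each sequence in the batch.
    position_ids = torch.arange(seq_len).unsqueeze(0).expand(batch_size, -1)

    with torch.no_grad():
        hidden_states = model(input_ids, position_ids=position_ids)

    # Expect one hidden state per token: (batch_size, seq_len, hidden_size).
    print(hidden_states.shape)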