import math
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint


class RMSNorm(nn.Module):
    """
    Root Mean Square Layer Normalization (RMSNorm).
    """

    def __init__(self, hidden_size: int, eps: float = 1e-5):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        variance = x.pow(2).mean(-1, keepdim=True)
        x = x * torch.rsqrt(variance + self.eps)
        return self.weight * x


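# A minimal usage sketch for RMSNorm (illustrative only; the shapes below are
# arbitrary choices, not values used elsewhere in this file):
#
#     norm = RMSNorm(hidden_size=8)
#     y = norm(torch.randn(2, 4, 8))
#     # y has the same shape as the input; each vector along the last dimension
#     # has approximately unit root-mean-square before the learned scale is applied.

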
class RotaryPositionalEmbedding(nn.Module):
    """
    Rotary Positional Embedding (RoPE) for transformers.
    """

    def __init__(self, dim: int, theta: float = 10000.0):
        super().__init__()
        self.dim = dim
        self.theta = theta

    def forward(self, x: torch.Tensor, seq_len: int) -> torch.Tensor:
        """
        Apply rotary positional embeddings to the input tensor.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, seq_len, num_heads, head_dim).
            seq_len (int): Sequence length.

        Returns:
            torch.Tensor: Tensor of the same shape with rotary positional embeddings applied.
        """
        batch_size, seq_len, num_heads, head_dim = x.shape

        # Generate position indices: (seq_len, 1)
        position = torch.arange(seq_len, dtype=torch.float32, device=x.device).unsqueeze(-1)

        # Generate inverse frequencies for each channel pair: theta^(-2i / head_dim)
        freqs = torch.exp(
            torch.arange(0, head_dim, 2, dtype=torch.float32, device=x.device)
            * -(math.log(self.theta) / head_dim)
        )

        # Compute sinusoids: (seq_len, head_dim // 2)
        sinusoid = position * freqs
        sin = torch.sin(sinusoid)
        cos = torch.cos(sinusoid)

        # Reshape sin and cos to broadcast against the input tensor
        sin = sin.unsqueeze(0).unsqueeze(2)  # (1, seq_len, 1, head_dim // 2)
        cos = cos.unsqueeze(0).unsqueeze(2)  # (1, seq_len, 1, head_dim // 2)

        # Rotate each (even, odd) channel pair by its position-dependent angle
        x_rotated = x.clone()
        x_rotated[..., 0::2] = x[..., 0::2] * cos - x[..., 1::2] * sin
        x_rotated[..., 1::2] = x[..., 1::2] * cos + x[..., 0::2] * sin
        return x_rotated


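# A minimal usage sketch for RotaryPositionalEmbedding (illustrative only; the
# shapes below are arbitrary):
#
#     rope = RotaryPositionalEmbedding(dim=16)
#     q = torch.randn(2, 10, 4, 16)   # (batch, seq_len, num_heads, head_dim)
#     q_rot = rope(q, seq_len=10)     # same shape; each even/odd channel pair is
#                                     # rotated by a position-dependent angle

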
class TransformerBlock(nn.Module):
    """
    A single transformer block with self-attention and feed-forward layers.
    """

    def __init__(
        self,
        hidden_size: int,
        num_attention_heads: int,
        intermediate_size: int,
        num_key_value_heads: int,
        rms_norm_eps: float,
        hidden_act: str = "silu",
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.head_dim = hidden_size // num_attention_heads

        # Ensure the hidden size is divisible by the number of attention heads
        if hidden_size % num_attention_heads != 0:
            raise ValueError(
                f"hidden_size ({hidden_size}) must be divisible by num_attention_heads ({num_attention_heads})"
            )
        # Grouped-query attention requires the query heads to be a multiple of the key/value heads
        if num_attention_heads % num_key_value_heads != 0:
            raise ValueError(
                f"num_attention_heads ({num_attention_heads}) must be divisible by num_key_value_heads ({num_key_value_heads})"
            )

        # Self-attention projections
        self.q_proj = nn.Linear(hidden_size, hidden_size)
        self.k_proj = nn.Linear(hidden_size, num_key_value_heads * self.head_dim)
        self.v_proj = nn.Linear(hidden_size, num_key_value_heads * self.head_dim)
        self.o_proj = nn.Linear(hidden_size, hidden_size)

        # Feed-forward (gated MLP) layers
        self.gate_proj = nn.Linear(hidden_size, intermediate_size)
        self.up_proj = nn.Linear(hidden_size, intermediate_size)
        self.down_proj = nn.Linear(intermediate_size, hidden_size)

        # Normalization layers
        self.input_norm = RMSNorm(hidden_size, eps=rms_norm_eps)
        self.post_attention_norm = RMSNorm(hidden_size, eps=rms_norm_eps)

        # Activation function
        self.act = nn.SiLU() if hidden_act == "silu" else nn.GELU()

        # Rotary positional embedding applied to queries and keys
        self.rope = RotaryPositionalEmbedding(self.head_dim)

    def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        # Use gradient checkpointing during training to trade compute for memory;
        # run a plain forward pass at inference time.
        if self.training:
            return checkpoint(self._forward, x, attention_mask, use_reentrant=False)
        return self._forward(x, attention_mask)

    def _forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        # Self-attention
        residual = x
        x = self.input_norm(x)
        batch_size, seq_len, _ = x.shape

        # Project inputs to queries, keys, and values
        q = self.q_proj(x).view(batch_size, seq_len, self.num_attention_heads, self.head_dim)
        k = self.k_proj(x).view(batch_size, seq_len, self.num_key_value_heads, self.head_dim)
        v = self.v_proj(x).view(batch_size, seq_len, self.num_key_value_heads, self.head_dim)

        # Apply rotary positional embeddings to queries and keys
        q = self.rope(q, seq_len)
        k = self.rope(k, seq_len)

        # scaled_dot_product_attention expects (batch, num_heads, seq_len, head_dim)
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        # Repeat key/value heads so grouped-query attention matches the query head count
        if self.num_key_value_heads != self.num_attention_heads:
            n_rep = self.num_attention_heads // self.num_key_value_heads
            k = k.repeat_interleave(n_rep, dim=1)
            v = v.repeat_interleave(n_rep, dim=1)

        # Scaled dot-product attention; fall back to causal masking when no mask is given
        attn_output = F.scaled_dot_product_attention(
            q, k, v,
            attn_mask=attention_mask,
            is_causal=attention_mask is None,
        )
        attn_output = attn_output.transpose(1, 2).reshape(batch_size, seq_len, self.hidden_size)
        attn_output = self.o_proj(attn_output)

        # Residual connection around attention
        x = residual + attn_output

        # Feed-forward network (gated MLP)
        residual = x
        x = self.post_attention_norm(x)
        gate = self.act(self.gate_proj(x))
        up = self.up_proj(x)
        ff_output = self.down_proj(gate * up)

        # Residual connection around the feed-forward network
        x = residual + ff_output
        return x


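# A minimal usage sketch for a single TransformerBlock (illustrative only; the
# hyperparameters below are arbitrary):
#
#     block = TransformerBlock(
#         hidden_size=64, num_attention_heads=4, intermediate_size=128,
#         num_key_value_heads=2, rms_norm_eps=1e-5,
#     )
#     out = block(torch.randn(2, 10, 64))   # output shape: (2, 10, 64)

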
class TransformerModel(nn.Module):
    def __init__(
        self,
        vocab_size: int,
        hidden_size: int,
        num_hidden_layers: int,
        num_attention_heads: int,
        intermediate_size: int,
        num_key_value_heads: int,
        max_position_embeddings: int,
        rms_norm_eps: float,
        hidden_act: str = "silu",
        tie_word_embeddings: bool = True,
    ):
        super().__init__()
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.max_position_embeddings = max_position_embeddings

        # Embedding layers (skip quantization for these)
        self.embed_tokens = nn.Embedding(vocab_size, hidden_size)
        self.embed_positions = nn.Embedding(max_position_embeddings, hidden_size)

        # Transformer blocks
        self.layers = nn.ModuleList([
            TransformerBlock(
                hidden_size=hidden_size,
                num_attention_heads=num_attention_heads,
                intermediate_size=intermediate_size,
                num_key_value_heads=num_key_value_heads,
                rms_norm_eps=rms_norm_eps,
                hidden_act=hidden_act,
            )
            for _ in range(num_hidden_layers)
        ])

        # Final normalization layer
        self.final_norm = RMSNorm(hidden_size, eps=rms_norm_eps)

        # Output layer (tied to the input embeddings if specified)
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
        if tie_word_embeddings:
            self.lm_head.weight = self.embed_tokens.weight

    def forward(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        # Embed tokens and (learned) positions
        seq_len = input_ids.size(1)
        position_ids = torch.arange(seq_len, dtype=torch.long, device=input_ids.device)
        token_embeddings = self.embed_tokens(input_ids)
        position_embeddings = self.embed_positions(position_ids)
        x = token_embeddings + position_embeddings

        # Pass through the transformer layers
        for layer in self.layers:
            x = layer(x, attention_mask)

        # Final normalization
        x = self.final_norm(x)

        # Output logits over the vocabulary
        logits = self.lm_head(x)
        return logits

    def generate(
        self,
        input_ids: torch.Tensor,
        max_length: int = 50,
        temperature: float = 1.0,
        top_k: int = 50,
        do_sample: bool = True,
    ) -> torch.Tensor:
        """
        Generate text autoregressively.

        Args:
            input_ids (torch.Tensor): Input token IDs of shape (batch_size, seq_len).
            max_length (int): Maximum length of the generated sequence.
            temperature (float): Sampling temperature. Higher values mean more random sampling.
            top_k (int): Top-k sampling. Only the top-k tokens are considered.
            do_sample (bool): Whether to sample from the distribution or take the argmax.

        Returns:
            torch.Tensor: Generated token IDs of shape (batch_size, max_length).
        """
        self.eval()
        with torch.no_grad():
            for _ in range(max_length - input_ids.size(1)):
                # Get the logits for the last position
                logits = self(input_ids)[:, -1, :]

                # Apply temperature
                logits = logits / temperature

                # Top-k filtering: mask out everything below the k-th largest logit
                if top_k > 0:
                    top_k_values, _ = torch.topk(logits, top_k)
                    logits[logits < top_k_values[:, -1].unsqueeze(-1)] = -float("Inf")

                # Convert logits to probabilities
                probs = F.softmax(logits, dim=-1)

                # Sample from the distribution or take the argmax
                if do_sample:
                    next_token = torch.multinomial(probs, num_samples=1)
                else:
                    next_token = torch.argmax(probs, dim=-1, keepdim=True)

                # Append the next token to the running sequence
                input_ids = torch.cat([input_ids, next_token], dim=-1)
        return input_ids


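# A sketch of how a caller could build an attention mask for the forward pass
# (illustrative only; pad_token_id is a placeholder, not defined in this file).
# F.scaled_dot_product_attention accepts a boolean mask in which True means
# "attend"; when no mask is passed, the blocks fall back to causal masking.
#
#     seq_len = input_ids.size(1)
#     causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
#     padding = (input_ids != pad_token_id)                    # (batch, seq_len)
#     mask = causal[None, None, :, :] & padding[:, None, None, :]
#     logits = model(input_ids, attention_mask=mask)

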
# Create the model from a configuration dictionary
def create_model_from_config(config: dict) -> TransformerModel:
    model_config = config["model"]["model_config"]
    return TransformerModel(
        vocab_size=model_config["vocab_size"],
        hidden_size=model_config["hidden_size"],
        num_hidden_layers=model_config["num_hidden_layers"],
        num_attention_heads=model_config["num_attention_heads"],
        intermediate_size=model_config["intermediate_size"],
        num_key_value_heads=model_config["num_key_value_heads"],
        max_position_embeddings=model_config["max_position_embeddings"],
        rms_norm_eps=model_config["rms_norm_eps"],
        hidden_act=model_config["hidden_act"],
        tie_word_embeddings=model_config["tie_word_embeddings"],
    )


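# The configuration layout assumed by create_model_from_config (inferred from
# the keys read above; the real config file may contain additional fields):
#
#     {
#         "model": {
#             "model_config": {
#                 "vocab_size": ...,
#                 "hidden_size": ...,
#                 "num_hidden_layers": ...,
#                 "num_attention_heads": ...,
#                 "intermediate_size": ...,
#                 "num_key_value_heads": ...,
#                 "max_position_embeddings": ...,
#                 "rms_norm_eps": ...,
#                 "hidden_act": ...,
#                 "tie_word_embeddings": ...
#             }
#         }
#     }

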
# Example usage
if __name__ == "__main__":
    import json

    # Load the configuration file
    with open("config_smollm2_135M.json", "r") as f:
        config = json.load(f)

    # Create the model
    model = create_model_from_config(config)
    print(model)
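
    # A minimal, self-contained smoke test (illustrative only): the tiny
    # hyperparameters below are arbitrary and are not taken from the config file.
    tiny_model = TransformerModel(
        vocab_size=1000,
        hidden_size=64,
        num_hidden_layers=2,
        num_attention_heads=4,
        intermediate_size=128,
        num_key_value_heads=2,
        max_position_embeddings=128,
        rms_norm_eps=1e-5,
    )
    dummy_ids = torch.randint(0, 1000, (1, 8))
    logits = tiny_model(dummy_ids)
    print("logits shape:", tuple(logits.shape))        # expected: (1, 8, 1000)
    generated = tiny_model.generate(dummy_ids, max_length=16)
    print("generated shape:", tuple(generated.shape))  # expected: (1, 16)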