# ByteGPT-small / modeling_bytegpt.py
import torch
import torch.nn as nn
from torch.nn import functional as F
from typing import Optional, Union
from transformers import PreTrainedModel, PretrainedConfig
from transformers.modeling_outputs import CausalLMOutput
from .configuration_bytegpt import ByteGPTConfig
try:
from flash_attn.flash_attention import FlashAttention
    FLASH_ATTENTION_AVAILABLE = torch.cuda.is_available()  # flash-attn kernels only run on CUDA
except ImportError:
FLASH_ATTENTION_AVAILABLE = False
class Head(nn.Module):
"""One head of self-attention.
Args:
head_size (int): The size of the head.
n_embd (int): The embedding dimension.
        block_size (int): The maximum sequence length used to build the causal mask.
dropout (float): The dropout rate.
use_flash_attention (bool): Whether to use Flash Attention.
Attributes:
key (nn.Linear): The linear layer for computing the keys.
query (nn.Linear): The linear layer for computing the queries.
value (nn.Linear): The linear layer for computing the values.
tril (torch.Tensor): The lower triangular matrix.
dropout (nn.Dropout): The dropout layer.
use_flash_attention (bool): Whether to use Flash Attention.
flash_attention (FlashAttention): The FlashAttention module.
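
    Example:
        A minimal sketch with illustrative hyperparameters (not taken from any
        shipped config); the head maps (batch, seq_len, n_embd) to
        (batch, seq_len, head_size):

        >>> head = Head(head_size=16, n_embd=64, block_size=32, dropout=0.1)
        >>> x = torch.randn(2, 32, 64)
        >>> head(x).shape
        torch.Size([2, 32, 16])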
"""
def __init__(
self,
head_size: int,
n_embd: int,
block_size: int,
dropout: float,
use_flash_attention: bool = False,
) -> None:
super().__init__()
self.key = nn.Linear(n_embd, head_size, bias=False)
self.query = nn.Linear(n_embd, head_size, bias=False)
self.value = nn.Linear(n_embd, head_size, bias=False)
self.dropout = nn.Dropout(dropout)
# Only enable flash attention if we're on CUDA
self.use_flash_attention = use_flash_attention and FLASH_ATTENTION_AVAILABLE
if self.use_flash_attention:
print("Using Flash Attention")
self.flash_attention = FlashAttention()
else:
if use_flash_attention:
print(
"Flash Attention requested but not available. Using standard attention."
)
            self.tril = torch.tril(torch.ones(block_size, block_size))  # causal mask; moved to the input device in forward()
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Perform forward pass through the attention head.
Args:
x (torch.Tensor): The input tensor of shape (batch_size, sequence_length, embedding_dimension).
Returns:
torch.Tensor: The output tensor of shape (batch_size, sequence_length, embedding_dimension).
"""
B, T, C = x.shape
k = self.key(x) # (B,T,head_size)
q = self.query(x) # (B,T,head_size)
v = self.value(x) # (B,T,head_size)
if self.use_flash_attention:
# Flash Attention expects shape (B, H, T, D)
out = self.flash_attention(q.unsqueeze(1), k.unsqueeze(1), v.unsqueeze(1))[
0
].squeeze(1)
else:
# Regular attention
self.tril = self.tril.to(x.device)
wei = q @ k.transpose(-2, -1) * k.shape[-1] ** -0.5 # (B, T, T)
wei = wei.masked_fill(self.tril[:T, :T] == 0, float("-inf")) # (B, T, T)
wei = F.softmax(wei, dim=-1) # (B, T, T)
wei = self.dropout(wei)
out = wei @ v # (B, T, head_size)
return out
class MultiHeadAttention(nn.Module):
"""Multiple heads of self-attention in parallel.
Args:
num_heads (int): The number of heads.
head_size (int): The size of each head.
n_embd (int): The embedding dimension.
        block_size (int): The maximum sequence length used for the causal attention mask.
dropout (float): The dropout rate.
use_flash_attention (bool): Whether to use Flash Attention.
Attributes:
        heads (nn.ModuleList): The list of attention heads.
proj (nn.Linear): The linear layer for projecting the concatenated heads.
dropout (nn.Dropout): The dropout layer.
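
    Example:
        Illustrative sketch (hyperparameters chosen for the example only);
        num_heads * head_size must equal n_embd for the output projection:

        >>> mha = MultiHeadAttention(num_heads=4, head_size=16, n_embd=64,
        ...                          block_size=32, dropout=0.1)
        >>> mha(torch.randn(2, 32, 64)).shape
        torch.Size([2, 32, 64])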
"""
def __init__(
self,
num_heads: int,
head_size: int,
n_embd: int,
block_size: int,
dropout: float,
use_flash_attention: bool = False,
) -> None:
super().__init__()
self.heads = nn.ModuleList(
[
Head(
head_size,
n_embd,
block_size,
dropout,
use_flash_attention=use_flash_attention,
)
for _ in range(num_heads)
]
)
        self.proj = nn.Linear(n_embd, n_embd)  # assumes num_heads * head_size == n_embd
self.dropout = nn.Dropout(dropout)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Perform forward pass through the multi-head attention layer.
Args:
x (torch.Tensor): The input tensor of shape (batch_size, sequence_length, embedding_dimension).
Returns:
torch.Tensor: The output tensor of shape (batch_size, sequence_length, embedding_dimension).
"""
out = torch.cat([h(x) for h in self.heads], dim=-1)
out = self.dropout(self.proj(out))
return out
class FeedForward(nn.Module):
"""Simple linear layer followed by a non-linearity.
Args:
n_embd (int): The embedding dimension.
dropout (float): The dropout rate.
Attributes:
net (nn.Sequential): The sequential network of linear layers and ReLU activation.
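
    Example:
        Illustrative sketch; the hidden layer expands to 4 * n_embd and is
        projected back, so the output shape matches the input:

        >>> ffwd = FeedForward(n_embd=64, dropout=0.1)
        >>> ffwd(torch.randn(2, 32, 64)).shape
        torch.Size([2, 32, 64])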
"""
def __init__(self, n_embd: int, dropout: float) -> None:
super().__init__()
self.net = nn.Sequential(
nn.Linear(n_embd, 4 * n_embd),
nn.ReLU(),
nn.Linear(4 * n_embd, n_embd),
nn.Dropout(dropout),
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Perform forward pass through the feedforward layer.
Args:
x (torch.Tensor): The input tensor of shape (batch_size, sequence_length, embedding_dimension).
Returns:
torch.Tensor: The output tensor of shape (batch_size, sequence_length, embedding_dimension).
"""
return self.net(x)
class Block(nn.Module):
"""Transformer block: communication followed by computation.
Args:
n_embd (int): The embedding dimension.
n_head (int): The number of attention heads.
        block_size (int): The maximum sequence length used for the causal attention mask.
dropout (float): The dropout rate.
use_flash_attention (bool): Whether to use Flash Attention.
Attributes:
sa (MultiHeadAttention): The multi-head attention layer.
ffwd (FeedForward): The feedforward layer.
ln1 (nn.LayerNorm): The layer normalization layer for the first sublayer.
ln2 (nn.LayerNorm): The layer normalization layer for the second sublayer.
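
    Example:
        Illustrative sketch (n_embd must be divisible by n_head); the residual
        connections keep the output shape equal to the input shape:

        >>> block = Block(n_embd=64, n_head=4, block_size=32, dropout=0.1)
        >>> block(torch.randn(2, 32, 64)).shape
        torch.Size([2, 32, 64])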
"""
def __init__(
self,
n_embd: int,
n_head: int,
block_size: int,
dropout: float,
use_flash_attention: bool = False,
) -> None:
super().__init__()
head_size = n_embd // n_head
self.sa = MultiHeadAttention(
n_head,
head_size,
n_embd,
block_size,
dropout,
use_flash_attention=use_flash_attention,
)
self.ffwd = FeedForward(n_embd, dropout)
self.ln1 = nn.LayerNorm(n_embd)
self.ln2 = nn.LayerNorm(n_embd)
        # Flash attention and the causal mask are handled inside each Head; the block only resolves the flag to report which attention path is used.
self.use_flash_attention = use_flash_attention and FLASH_ATTENTION_AVAILABLE
if self.use_flash_attention:
print("Using Flash Attention")
elif use_flash_attention:
print(
"Flash Attention requested but not available. Using standard attention."
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Perform forward pass through the transformer block.
Args:
x (torch.Tensor): The input tensor of shape (batch_size, sequence_length, embedding_dimension).
Returns:
torch.Tensor: The output tensor of shape (batch_size, sequence_length, embedding_dimension).
"""
x = x + self.sa(self.ln1(x))
x = x + self.ffwd(self.ln2(x))
return x
class ByteGPTForCausalLM(PreTrainedModel):
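    """ByteGPT causal language model: a small GPT-style decoder wrapped in the
    Hugging Face PreTrainedModel interface.

    Example:
        A minimal sketch of building the model directly from a config and running
        a forward pass. The hyperparameters below are illustrative only and assume
        ByteGPTConfig accepts these field names as keyword arguments; they are not
        the shipped ByteGPT-small values:

        >>> config = ByteGPTConfig(vocab_size=256, block_size=32, n_embd=64,
        ...                        n_head=4, n_layer=2, dropout=0.1,
        ...                        use_flash_attention=False)
        >>> model = ByteGPTForCausalLM(config)
        >>> input_ids = torch.randint(0, 256, (1, 16))
        >>> model(input_ids, attention_mask=torch.ones_like(input_ids)).logits.shape
        torch.Size([1, 16, 256])
    """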
config_class = ByteGPTConfig
def __init__(
self,
config: ByteGPTConfig,
):
super().__init__(config)
self.block_size = config.block_size
self.token_embedding_table = nn.Embedding(config.vocab_size, config.n_embd)
self.position_embedding_table = nn.Embedding(config.block_size, config.n_embd)
self.blocks = nn.Sequential(
*[
Block(
config.n_embd,
config.n_head,
config.block_size,
config.dropout,
config.use_flash_attention,
)
for _ in range(config.n_layer)
]
)
self.ln_f = nn.LayerNorm(config.n_embd)
self.lm_head = nn.Linear(config.n_embd, config.vocab_size)
def forward(
self,
input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        return_dict: bool = True,
        labels: Optional[torch.Tensor] = None,
        **kwargs
    ) -> Union[CausalLMOutput, tuple[torch.Tensor, torch.Tensor]]:
"""
        Forward pass of the model.

        Args:
            input_ids (torch.Tensor): Token ids of shape (batch_size, sequence_length).
            attention_mask (torch.Tensor, optional): Accepted for API compatibility; not used by this model.
            return_dict (bool): If True (default), return a CausalLMOutput; otherwise a (logits, loss) tuple.
            labels (torch.Tensor, optional): Target token ids used to compute the cross-entropy loss.

        Returns:
            CausalLMOutput with ``logits`` and ``loss``, or the tuple (logits, loss) if return_dict is False.
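
        Example:
            Illustrative shapes only, assuming ``model`` is a ByteGPTForCausalLM
            whose config has vocab_size=256 and block_size >= 16:

            >>> ids = torch.randint(0, 256, (1, 16))
            >>> out = model(ids, attention_mask=torch.ones_like(ids), labels=ids)
            >>> out.logits.shape, out.loss.shape
            (torch.Size([1, 16, 256]), torch.Size([]))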
"""
B, T = input_ids.shape
# Token and position embeddings
tok_emb = self.token_embedding_table(input_ids) # (B,T,C)
pos_emb = self.position_embedding_table(
torch.arange(T, device=input_ids.device)
) # (T,C)
x = tok_emb + pos_emb # (B,T,C)
# Transformer blocks
x = self.blocks(x) # (B,T,C)
x = self.ln_f(x) # (B,T,C)
# Language model head
logits = self.lm_head(x) # (B,T,vocab_size)
        if labels is None:
            loss = None
        else:
            # Flatten only for the loss so the returned logits keep shape (B, T, vocab_size).
            # Labels are used as-is; no shift is applied here.
            B, T, C = logits.shape
            loss = F.cross_entropy(logits.view(B * T, C), labels.view(B * T))
if not return_dict:
return (logits, loss)
return CausalLMOutput(logits=logits, loss=loss)
def prepare_inputs_for_generation(self, input_ids, **kwargs):
# Required for .generate() to work
return {
"input_ids": input_ids,
"attention_mask": torch.ones_like(input_ids),
}
    # def generate(
    #     self, input_ids: torch.Tensor, max_new_tokens: int, temperature: float = 1.0
    # ) -> torch.Tensor:
    #     """
    #     Generate text tokens autoregressively.
    #     Args:
    #         input_ids: Context tokens.
    #         max_new_tokens: Number of tokens to generate.
    #         temperature: Sampling temperature (higher = more random).
    #     Returns:
    #         Generated token sequence.
    #     """
    #     for _ in range(max_new_tokens):
    #         # Crop context to the last block_size tokens
    #         idx_cond = input_ids[:, -self.block_size :]
    #         # Get predictions as a (logits, loss) tuple
    #         logits, _ = self(idx_cond, return_dict=False)
    #         # Focus only on the last time step
    #         logits = logits[:, -1, :] / temperature
    #         # Apply softmax to get probabilities
    #         probs = F.softmax(logits, dim=-1)
    #         # Sample from the distribution
    #         idx_next = torch.multinomial(probs, num_samples=1)
    #         # Append sampled token to the running sequence
    #         input_ids = torch.cat((input_ids, idx_next), dim=1)
    #     return input_ids