Argonne-1.5 / model.py
Youzhi Yu
Fix generate method to handle CausalLMOutput, plus other updates
ea46d13
raw
history blame
17.1 kB
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import (
PretrainedConfig,
PreTrainedModel,
AutoConfig,
AutoModel,
AutoModelForCausalLM
)
from transformers.modeling_outputs import CausalLMOutput
from typing import Optional
class ArgonneConfig(PretrainedConfig):
    """Hyper-parameters describing an Argonne decoder-only transformer.

    Args:
        vocab_size: Number of entries in the token-embedding table.
        block_size: Maximum sequence length; sizes the learned position
            table and the causal attention mask.
        n_layer: Number of transformer blocks.
        n_head: Number of attention heads per block.
        n_embd: Width of the embeddings / residual stream.
        dropout: Dropout probability used by the embedding, attention and
            MLP dropout layers.
        use_flash_attn: When True, attention layers use
            ``F.scaled_dot_product_attention`` if PyTorch provides it.
        **kwargs: Passed through untouched to ``PretrainedConfig``.
    """

    model_type = "argonne"

    def __init__(
        self,
        vocab_size=12000,
        block_size=2048,
        n_layer=24,
        n_head=24,
        n_embd=1296,
        dropout=0.1,
        use_flash_attn=True,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Vocabulary and context window.
        self.vocab_size = vocab_size
        self.block_size = block_size

        # Network shape.
        self.n_layer = n_layer
        self.n_head = n_head
        self.n_embd = n_embd

        # Regularisation and attention-kernel selection.
        self.dropout = dropout
        self.use_flash_attn = use_flash_attn
class Block(nn.Module):
    """One pre-LayerNorm transformer block.

    The input is normalised before each sub-layer (self-attention, then MLP)
    and each sub-layer's output is added back onto the residual stream.
    """

    def __init__(self, config):
        super().__init__()
        # Attribute names and creation order are left exactly as they were:
        # the names become state-dict keys and the order determines how the
        # RNG is consumed during weight initialisation.
        self.ln1 = nn.LayerNorm(config.n_embd)
        self.attn = CausalSelfAttention(config)
        self.ln2 = nn.LayerNorm(config.n_embd)
        self.mlp = MLP(config)

    def forward(self, x):
        """Return ``x`` after the attention and MLP residual updates."""
        after_attention = x + self.attn(self.ln1(x))
        return after_attention + self.mlp(self.ln2(after_attention))
class CausalSelfAttention(nn.Module):
    """Multi-head self-attention in which a position only attends to itself and earlier positions.

    Two numerically equivalent code paths exist: PyTorch's fused
    ``F.scaled_dot_product_attention`` (used when it is available and
    ``use_flash_attn`` is true) and an explicit softmax(QK^T / sqrt(d)) V
    fallback that relies on the registered lower-triangular ``mask`` buffer.
    """

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0, "Embedding dim must be divisible by n_head"

        self.n_head = config.n_head
        self.head_dim = config.n_embd // config.n_head

        # Separate projections for queries, keys and values.
        self.query = nn.Linear(config.n_embd, config.n_embd)
        self.key = nn.Linear(config.n_embd, config.n_embd)
        self.value = nn.Linear(config.n_embd, config.n_embd)

        # Dropout on the attention weights and on the projected output.
        self.attn_drop = nn.Dropout(config.dropout)
        self.resid_drop = nn.Dropout(config.dropout)

        # Output projection applied after the heads are merged again.
        self.proj = nn.Linear(config.n_embd, config.n_embd)

        self.use_flash_attn = getattr(config, 'use_flash_attn', True)

        # Lower-triangular matrix of ones, shaped (1, 1, block_size, block_size),
        # consumed only by the explicit (non-fused) attention path.
        lower_triangle = torch.tril(torch.ones(config.block_size, config.block_size))
        self.register_buffer(
            "mask",
            lower_triangle.view(1, 1, config.block_size, config.block_size),
        )

    def _split_heads(self, projected, batch, steps):
        """Reshape (batch, steps, n_embd) into (batch, n_head, steps, head_dim)."""
        return projected.view(batch, steps, self.n_head, self.head_dim).transpose(1, 2)

    def forward(self, x):
        """Apply causal self-attention to ``x`` of shape (batch, steps, n_embd)."""
        batch, steps, width = x.size()

        queries = self._split_heads(self.query(x), batch, steps)
        keys = self._split_heads(self.key(x), batch, steps)
        values = self._split_heads(self.value(x), batch, steps)

        fused_available = hasattr(F, 'scaled_dot_product_attention')
        if fused_available and self.use_flash_attn:
            # is_causal=True makes PyTorch build the causal mask itself, so no
            # explicit attention mask is passed.
            context = F.scaled_dot_product_attention(
                queries,
                keys,
                values,
                dropout_p=self.attn_drop.p if self.training else 0.0,
                is_causal=True,
            )
        else:
            # Explicit fallback: scaled scores, causal masking, softmax, dropout.
            scores = (queries @ keys.transpose(-2, -1)) / math.sqrt(self.head_dim)
            blocked = self.mask[:, :, :steps, :steps] == 0
            scores = scores.masked_fill(blocked, float('-inf'))
            weights = self.attn_drop(torch.softmax(scores, dim=-1))
            context = weights @ values

        # Merge the heads back into a single (batch, steps, n_embd) tensor.
        merged = context.transpose(1, 2).contiguous().view(batch, steps, width)
        return self.resid_drop(self.proj(merged))
class MLP(nn.Module):
    """Position-wise feed-forward network: expand 4x, GELU, project back.

    The same dropout layer is applied after the activation and again after
    the output projection.
    """

    def __init__(self, config):
        super().__init__()
        hidden_width = 4 * config.n_embd
        self.fc1 = nn.Linear(config.n_embd, hidden_width)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_width, config.n_embd)
        self.drop = nn.Dropout(config.dropout)

    def forward(self, x):
        """Return the feed-forward transformation of ``x`` (last dim n_embd)."""
        expanded = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(expanded))
class ArgonneModel(PreTrainedModel):
    """GPT-style decoder-only language model with learned position embeddings.

    The model can run on a single device, be sharded by ``accelerate``
    (``device_map="auto"``), or be split manually into sequential pipeline
    stages with :meth:`distribute_model`.
    """

    config_class = ArgonneConfig
    # Used for device_map="auto": a Block contains residual connections, so it
    # must be placed on one device as a unit.
    _no_split_modules = ["Block"]

    def __init__(self, config, device_map=None):
        """Build the network and optionally place it according to ``device_map``.

        Args:
            config: An ``ArgonneConfig`` (or compatible) instance.
            device_map: ``None`` to leave the model where it was created, or
                ``"auto"`` to let ``accelerate`` spread it over the available
                devices. Other values are currently ignored (with a notice).
        """
        super().__init__(config)
        # Creation order is kept stable so weight initialisation consumes the
        # RNG in the same sequence as before.
        self.token_embedding = nn.Embedding(config.vocab_size, config.n_embd)
        self.position_embedding = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd))
        self.drop = nn.Dropout(config.dropout)

        # Transformer stack.
        self.blocks = nn.ModuleList([Block(config) for _ in range(config.n_layer)])

        # Final LayerNorm + output head (no weight tying with the embedding).
        self.ln_f = nn.LayerNorm(config.n_embd)
        self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        nn.init.normal_(self.position_embedding, mean=0.0, std=0.02)
        self.post_init()

        # Populated by distribute_model() when pipeline parallelism is used.
        self.pipeline_stages = None
        self.devices = []

        if device_map is not None:
            self.setup_device_map(device_map)

    def setup_device_map(self, device_map):
        """Place the model on devices according to ``device_map``.

        Only ``device_map="auto"`` is implemented: ``accelerate`` infers a
        placement (never splitting a ``Block``) and dispatches the model.
        If ``accelerate`` is not installed, or another value is passed, the
        model is left where it is and a notice is printed.
        """
        if device_map != "auto":
            # Previously this case was silently ignored; tell the caller.
            print(f"Custom device_map={device_map!r} is not supported; model left on its current device.")
            return

        try:
            from accelerate import dispatch_model
            from accelerate.utils import infer_auto_device_map
        except ImportError:
            print("The 'accelerate' library is required for device_map='auto'. Please install it with 'pip install accelerate'.")
            print("Continuing with model on CPU or default device.")
            return

        # Honour _no_split_modules so a residual Block is never cut in half.
        auto_device_map = infer_auto_device_map(
            self, no_split_module_classes=self._no_split_modules
        )
        dispatch_model(self, device_map=auto_device_map)
        print(f"Model automatically distributed across devices with device_map: {auto_device_map}")

    def distribute_model(self, device_ids=None):
        """Split the transformer blocks into sequential stages, one per device.

        Args:
            device_ids: Sequence of device identifiers (e.g. ``["cuda:0",
                "cuda:1"]``). When ``None``, all visible CUDA devices are used.

        Raises:
            ValueError: If no CUDA device is available (``device_ids`` is
                ``None``) or ``device_ids`` is empty.

        Side effects:
            Sets ``self.devices`` and ``self.pipeline_stages``; moves the
            embeddings to the first device and ``ln_f`` / ``head`` to the last.
        """
        if device_ids is None:
            num_gpus = torch.cuda.device_count()
            if num_gpus < 1:
                raise ValueError("No GPUs found—can't do pipeline parallel on CPU only.")
            device_ids = [f"cuda:{i}" for i in range(num_gpus)]
        else:
            device_ids = list(device_ids)
            if not device_ids:
                # Would otherwise surface as a ZeroDivisionError below.
                raise ValueError("`device_ids` must contain at least one device.")

        # Stored so the training loop can keep referencing model.devices.
        self.devices = [torch.device(d) for d in device_ids]

        # NOTE(review): the stages wrap the same Block objects held in
        # self.blocks, so they are registered twice as sub-modules; check the
        # effect on state_dict() keys before saving a distributed model.
        self.pipeline_stages = nn.ModuleList()
        num_gpus = len(device_ids)
        blocks_per_gpu = math.ceil(len(self.blocks) / num_gpus)

        start_idx = 0
        for device_id in device_ids:
            end_idx = min(start_idx + blocks_per_gpu, len(self.blocks))
            stage = nn.Sequential(*self.blocks[start_idx:end_idx]).to(device_id)
            self.pipeline_stages.append(stage)
            start_idx = end_idx
            if end_idx >= len(self.blocks):
                # Fewer stages than devices are needed; stop early.
                break

        # Embeddings live on the first device.
        first_device = device_ids[0]
        self.token_embedding = self.token_embedding.to(first_device)
        # For an nn.Parameter, move the underlying data rather than rebinding
        # the attribute (which would replace the registered parameter).
        self.position_embedding.data = self.position_embedding.data.to(first_device)
        self.drop = self.drop.to(first_device)

        # Final LayerNorm + head live on the last device.
        last_device = device_ids[-1]
        self.ln_f = self.ln_f.to(last_device)
        self.head = self.head.to(last_device)

        print(f"Model distributed across {len(device_ids)} devices")
        print(f"First device: {first_device}, Last device: {last_device}")
        print(f"Transformer layers per device: ~{blocks_per_gpu}")

    def _init_weights(self, module):
        """Initialise Linear/Embedding weights from N(0, 0.02) and zero the biases."""
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def prepare_for_compile(self):
        """Return the model ready for ``torch.compile()``.

        No special handling is performed at present; ``self`` is returned
        unchanged.
        """
        return self

    def _embed(self, idx):
        """Return dropout(token embedding + position embedding) for ``idx`` of shape [batch, seq]."""
        _, t = idx.size()
        assert t <= self.config.block_size, "Sequence length exceeds block size"
        token_embeddings = self.token_embedding(idx)
        position_embeddings = self.position_embedding[:, :t, :]
        return self.drop(token_embeddings + position_embeddings)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        labels=None,
        **kwargs
    ):
        """HF-friendly forward pass.

        Args:
            input_ids (torch.LongTensor): Token ids, [batch_size, seq_len]; a
                1-D tensor is treated as a batch of one.
            attention_mask (torch.LongTensor, optional): Accepted for API
                compatibility; currently ignored.
            labels (torch.LongTensor, optional): Targets with the same shape
                as ``input_ids``. The loss is the cross-entropy between the
                logits at each position and the label at that same position;
                no shifting is applied here.
            **kwargs: Ignored extra arguments (e.g. past_key_values).

        Returns:
            CausalLMOutput with ``loss`` (``None`` when ``labels`` is not
            given) and ``logits`` of shape [batch_size, seq_len, vocab_size].

        Raises:
            ValueError: If ``input_ids`` is ``None``.
        """
        if input_ids is None:
            raise ValueError("`input_ids` must be provided.")

        idx = input_ids
        if idx.dim() == 1:
            idx = idx.unsqueeze(0)

        # Embeddings always run where the token-embedding table lives.
        idx = idx.to(self.token_embedding.weight.device)
        hidden_states = self._embed(idx)

        if self.pipeline_stages is None:
            # Single-device (or accelerate-dispatched) path.
            for block in self.blocks:
                hidden_states = block(hidden_states)
        else:
            # Pipeline path: hop to each stage's device before running it.
            for stage in self.pipeline_stages:
                stage_device = next(stage.parameters()).device
                hidden_states = stage(hidden_states.to(stage_device))
            hidden_states = hidden_states.to(self.ln_f.weight.device)

        hidden_states = self.ln_f(hidden_states)
        logits = self.head(hidden_states)

        loss = None
        if labels is not None:
            # Flatten only for the loss; the returned logits keep their
            # [batch, seq, vocab] shape so callers can index by position.
            flat_logits = logits.view(-1, logits.size(-1))
            flat_targets = labels.to(logits.device).reshape(-1)
            loss = F.cross_entropy(flat_logits, flat_targets)

        return CausalLMOutput(
            loss=loss,
            logits=logits,
        )

    @staticmethod
    def _filter_logits(logits, top_k, top_p):
        """Mask logits outside the top-k set and/or the top-p nucleus with -inf."""
        if top_k is not None and top_k > 0:
            # Clamp so an oversized top_k does not make torch.topk fail.
            k = min(top_k, logits.size(-1))
            threshold = torch.topk(logits, k)[0][..., -1, None]
            logits = logits.masked_fill(logits < threshold, float('-inf'))

        if top_p is not None:
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
            sorted_indices_to_remove = cumulative_probs > top_p
            # Shift right so the first token that crosses the threshold is kept.
            sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
            sorted_indices_to_remove[..., 0] = False
            # Map the removal flags from sorted order back to vocabulary order.
            filter_mask = sorted_indices_to_remove.scatter(
                dim=1, index=sorted_indices, src=sorted_indices_to_remove
            )
            logits = logits.masked_fill(filter_mask, float('-inf'))

        return logits

    @torch.no_grad()
    def generate(
        self,
        input_ids: Optional[torch.Tensor] = None,
        max_length: int = 50,
        do_sample: bool = True,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        temperature: float = 0.7,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs
    ):
        """Autoregressively extend ``input_ids`` up to ``max_length`` tokens.

        Accepts the common HF argument names but runs a simple custom loop
        (no KV cache, no beam search, no EOS handling).

        Args:
            input_ids (Tensor): Prompt tokens, [batch_size, seq_len]; a 1-D
                tensor is treated as a batch of one.
            max_length (int): Total length of the returned sequences
                (prompt + new tokens).
            do_sample (bool): Sample from the distribution when True; take
                the arg-max when False.
            top_k (int): Keep only the k most likely tokens when sampling;
                ``None`` or ``0`` disables the filter.
            top_p (float): Nucleus-sampling threshold; ``None`` disables it.
            temperature (float): Sampling temperature; must be > 0 when
                sampling. It does not change the arg-max, so it is not
                applied in greedy mode.
            attention_mask (Tensor): Accepted for compatibility; unused.
            **kwargs: Ignored extra arguments (e.g. num_beams).

        Returns:
            Tensor of shape [batch_size, max_length] containing the prompt
            followed by the generated tokens.

        Raises:
            ValueError: If ``input_ids`` is missing, the prompt is already
                at least ``max_length`` long, or ``temperature`` is not
                positive while sampling.
        """
        # The model is switched to eval mode and left in it.
        self.eval()

        if self.pipeline_stages is not None and len(self.devices) > 0:
            device = self.devices[0]
        else:
            device = next(self.parameters()).device

        if input_ids is None:
            raise ValueError("`input_ids` must be provided for generation.")
        if input_ids.dim() == 1:
            input_ids = input_ids.unsqueeze(0)

        current_length = input_ids.shape[1]
        if current_length >= max_length:
            raise ValueError(f"Current sequence length {current_length} >= max_length={max_length}")
        if do_sample and temperature <= 0:
            raise ValueError(f"`temperature` must be > 0 when sampling, got {temperature}")

        generated = input_ids.to(device)
        block_size = self.config.block_size

        for _ in range(max_length - current_length):
            # Feed at most block_size trailing tokens to the model, but keep
            # the full sequence in `generated` so the prompt is never dropped
            # from the returned tensor.
            context = generated[:, -block_size:]
            logits = self(context).logits[:, -1, :]

            if do_sample:
                if temperature != 1.0:
                    logits = logits / temperature
                logits = self._filter_logits(logits, top_k, top_p)
                probs = F.softmax(logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
            else:
                next_token = torch.argmax(logits, dim=-1, keepdim=True)

            # In pipeline mode the logits come from the last device.
            generated = torch.cat([generated, next_token.to(device)], dim=1)

        return generated
# Hook the Argonne classes into the transformers Auto* factories so that the
# "argonne" model type resolves to ArgonneConfig / ArgonneModel.
AutoConfig.register("argonne", ArgonneConfig)
for _auto_class in (AutoModel, AutoModelForCausalLM):
    _auto_class.register(ArgonneConfig, ArgonneModel)
del _auto_class  # do not leave the loop variable in the module namespace