# Argonne-1.5 / model.py
# Author: Youzhi Yu
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import (
PretrainedConfig,
PreTrainedModel,
AutoConfig,
AutoModel,
AutoModelForCausalLM
)
class ArgonneConfig(PretrainedConfig):
model_type = "argonne"
def __init__(self, vocab_size=12000, block_size=2048, n_layer=24, n_head=24, n_embd=1296, dropout=0.1, use_flash_attn=True, **kwargs):
super().__init__(**kwargs)
self.vocab_size = vocab_size
self.block_size = block_size
self.n_layer = n_layer
self.n_head = n_head
self.n_embd = n_embd
self.dropout = dropout
self.use_flash_attn = use_flash_attn
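
# Example (illustrative): constructing a configuration; any keyword can be
# overridden, e.g. a smaller model for quick experiments.
#   config = ArgonneConfig(n_layer=4, n_head=4, n_embd=256)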
class Block(nn.Module):
def __init__(self, config):
super().__init__()
self.ln1 = nn.LayerNorm(config.n_embd)
self.attn = CausalSelfAttention(config)
self.ln2 = nn.LayerNorm(config.n_embd)
self.mlp = MLP(config)
def forward(self, x):
x = x + self.attn(self.ln1(x))
x = x + self.mlp(self.ln2(x))
return x
class CausalSelfAttention(nn.Module):
def __init__(self, config):
super().__init__()
assert config.n_embd % config.n_head == 0, "Embedding dim must be divisible by n_head"
self.n_head = config.n_head
self.head_dim = config.n_embd // config.n_head
self.query = nn.Linear(config.n_embd, config.n_embd)
self.key = nn.Linear(config.n_embd, config.n_embd)
self.value = nn.Linear(config.n_embd, config.n_embd)
self.attn_drop = nn.Dropout(config.dropout)
self.resid_drop = nn.Dropout(config.dropout)
self.proj = nn.Linear(config.n_embd, config.n_embd)
self.use_flash_attn = getattr(config, 'use_flash_attn', True)
# Register the causal mask for the traditional attention path
self.register_buffer(
"mask",
torch.tril(torch.ones(config.block_size, config.block_size))
.view(1, 1, config.block_size, config.block_size)
)
def forward(self, x):
b, t, c = x.size()
q = self.query(x).view(b, t, self.n_head, self.head_dim).transpose(1, 2)
k = self.key(x).view(b, t, self.n_head, self.head_dim).transpose(1, 2)
v = self.value(x).view(b, t, self.n_head, self.head_dim).transpose(1, 2)
if hasattr(F, 'scaled_dot_product_attention') and self.use_flash_attn:
# When using is_causal=True, don't provide an attention mask
attn_output = F.scaled_dot_product_attention(
q, k, v,
dropout_p=self.attn_drop.p if self.training else 0.0,
is_causal=True # Let PyTorch handle the causal mask internally
)
attn_output = attn_output.transpose(1, 2).contiguous().view(b, t, c)
y = self.resid_drop(self.proj(attn_output))
return y
else:
# Original attention implementation (fallback)
att = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
att = att.masked_fill(self.mask[:, :, :t, :t] == 0, float('-inf'))
att = torch.softmax(att, dim=-1)
att = self.attn_drop(att)
y = att @ v
y = y.transpose(1, 2).contiguous().view(b, t, c)
y = self.resid_drop(self.proj(y))
return y
class MLP(nn.Module):
def __init__(self, config):
super().__init__()
self.fc1 = nn.Linear(config.n_embd, 4 * config.n_embd)
self.act = nn.GELU()
self.fc2 = nn.Linear(4 * config.n_embd, config.n_embd)
self.drop = nn.Dropout(config.dropout)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
class ArgonneModel(PreTrainedModel):
config_class = ArgonneConfig
def __init__(self, config, device_map=None):
super().__init__(config)
# Create embeddings on CPU initially
self.token_embedding = nn.Embedding(config.vocab_size, config.n_embd)
self.position_embedding = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd))
self.drop = nn.Dropout(config.dropout)
# Build all blocks
self.blocks = nn.ModuleList([Block(config) for _ in range(config.n_layer)])
# Final LayerNorm + output head
self.ln_f = nn.LayerNorm(config.n_embd)
self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
nn.init.normal_(self.position_embedding, mean=0.0, std=0.02)
self.post_init()
# For pipeline parallelism
self.pipeline_stages = None
self.devices = []
# Handle device_map="auto" for inference
if device_map is not None:
self.setup_device_map(device_map)
def setup_device_map(self, device_map):
"""
Set up the model on devices according to device_map.
If device_map="auto", use accelerate to automatically assign model parts to devices.
"""
if device_map == "auto":
try:
from accelerate import dispatch_model
from accelerate.utils import infer_auto_device_map
# Get device map automatically
auto_device_map = infer_auto_device_map(self)
# Dispatch model across devices
dispatch_model(self, device_map=auto_device_map)
print(f"Model automatically distributed across devices with device_map: {auto_device_map}")
except ImportError:
print("The 'accelerate' library is required for device_map='auto'. Please install it with 'pip install accelerate'.")
print("Continuing with model on CPU or default device.")
        else:
            # Custom device maps (an explicit mapping of model components to
            # devices) are not implemented yet; the model stays on its current
            # device.
            pass
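
    # Usage sketch (illustrative): letting accelerate place the modules
    # automatically when the 'accelerate' package is installed.
    #   model = ArgonneModel(config, device_map="auto")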
def distribute_model(self, device_ids=None):
"""
Distribute the model blocks across multiple GPU devices in a pipeline style.
If 'device_ids' is None, we'll discover all available GPUs.
"""
if device_ids is None:
num_gpus = torch.cuda.device_count()
if num_gpus < 1:
raise ValueError("No GPUs found—can't do pipeline parallel on CPU only.")
device_ids = [f"cuda:{i}" for i in range(num_gpus)]
# Store them so the training loop can keep referencing model.devices
self.devices = [torch.device(d) for d in device_ids]
self.pipeline_stages = nn.ModuleList()
num_gpus = len(device_ids)
blocks_per_gpu = math.ceil(len(self.blocks) / num_gpus)
start_idx = 0
for i in range(num_gpus):
end_idx = min(start_idx + blocks_per_gpu, len(self.blocks))
stage_blocks = self.blocks[start_idx:end_idx]
stage = nn.Sequential(*stage_blocks).to(device_ids[i])
self.pipeline_stages.append(stage)
start_idx = end_idx
if end_idx >= len(self.blocks):
break
# Move embeddings to the first device
first_device = device_ids[0]
self.token_embedding = self.token_embedding.to(first_device)
# For nn.Parameter, we need to move the data, not replace the parameter
self.position_embedding.data = self.position_embedding.data.to(first_device)
self.drop = self.drop.to(first_device)
# Move final LayerNorm + head to the last device
last_device = device_ids[-1]
self.ln_f = self.ln_f.to(last_device)
self.head = self.head.to(last_device)
print(f"Model distributed across {len(device_ids)} devices")
print(f"First device: {first_device}, Last device: {last_device}")
print(f"Transformer layers per device: ~{blocks_per_gpu}")
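
    # Usage sketch (illustrative): splitting the blocks across two GPUs for
    # pipeline-style training; the device ids below are examples.
    #   model.distribute_model(["cuda:0", "cuda:1"])
    #   # or let the method discover every visible GPU:
    #   model.distribute_model()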
def _init_weights(self, module):
if isinstance(module, nn.Linear):
nn.init.normal_(module.weight, mean=0.0, std=0.02)
if module.bias is not None:
nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
nn.init.normal_(module.weight, mean=0.0, std=0.02)
def prepare_for_compile(self):
"""
Prepare model for torch.compile() by ensuring all components
are compatible with the compiler.
"""
# Some models may need special handling for compilation
# For now, we'll just return self since our model structure should be compatible
return self
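
    # Usage sketch (illustrative), assuming PyTorch 2.x where torch.compile is
    # available:
    #   compiled_model = torch.compile(model.prepare_for_compile())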
def forward(self, idx, targets=None):
"""
If self.pipeline_stages is None, we do a normal single-device forward
(whatever device everything is currently on—CPU or a single GPU).
Otherwise, we do a pipeline parallel forward.
"""
        # Add a batch dimension if a 1D tensor of token ids was passed; keeping
        # this branch simple also helps torch.compile trace the method.
        if idx.dim() == 1:
            idx = idx.unsqueeze(0)
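        # Both the single-device and pipeline paths return a (logits, loss) tuple:
        # logits have shape (batch, seq_len, vocab_size); loss is a scalar
        # cross-entropy value, or None when no targets are given.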
if self.pipeline_stages is None:
# Single-device forward pass
device = self.token_embedding.weight.device
idx = idx.to(device)
b, t = idx.size()
assert t <= self.config.block_size, "Sequence length exceeds block size"
token_embeddings = self.token_embedding(idx)
position_embeddings = self.position_embedding[:, :t, :]
hidden_states = self.drop(token_embeddings + position_embeddings)
for block in self.blocks:
hidden_states = block(hidden_states)
hidden_states = self.ln_f(hidden_states)
logits = self.head(hidden_states)
            loss = None
            if targets is not None:
                targets = targets.to(device)
                # Flatten only for the loss so the returned logits keep their
                # (batch, seq_len, vocab_size) shape.
                loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
            return logits, loss
else:
# Pipeline parallel forward
first_device = next(self.token_embedding.parameters()).device
last_device = next(self.ln_f.parameters()).device
x = idx.to(first_device)
b, t = x.size()
assert t <= self.config.block_size, "Sequence length exceeds block size"
token_embeddings = self.token_embedding(x)
position_embeddings = self.position_embedding[:, :t, :]
hidden_states = self.drop(token_embeddings + position_embeddings)
# Pass through each pipeline stage in sequence
for stage_idx, stage in enumerate(self.pipeline_stages):
device_stage = next(stage.parameters()).device
hidden_states = hidden_states.to(device_stage)
hidden_states = stage(hidden_states)
# Explicitly move to last device before final operations
hidden_states = hidden_states.to(last_device)
hidden_states = self.ln_f(hidden_states)
logits = self.head(hidden_states)
            loss = None
            if targets is not None:
                targets = targets.to(last_device)
                # As in the single-device path, flatten only for the loss and
                # return the logits in their original shape.
                loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
            return logits, loss
@torch.no_grad()
def generate(self, input_ids, max_new_tokens, temperature=0.7, top_k=None, top_p=None, sample=True):
"""
Generate text using the model.
Args:
input_ids: Input token IDs to continue from
max_new_tokens: Number of tokens to generate
temperature: Temperature for sampling (higher = more random)
top_k: If set, only sample from the top k most likely tokens
top_p: If set, sample from the smallest set of tokens whose cumulative probability exceeds p
sample: If True, sample from the distribution; if False, use greedy decoding
Returns:
Tensor containing the input_ids extended with max_new_tokens generated tokens
"""
self.eval()
# Determine which device to use - explicitly use first device for consistency
if self.pipeline_stages is not None and len(self.devices) > 0:
device = self.devices[0] # Always use first device for generation
else:
device = next(self.parameters()).device
# Ensure input is on the correct device
generated = input_ids.to(device)
for _ in range(max_new_tokens):
# Truncate if necessary to fit within the model's context window
if generated.shape[1] > self.config.block_size:
generated = generated[:, -self.config.block_size:]
# Forward pass
logits, _ = self.forward(generated)
# Make sure logits are on the same device
logits = logits.to(device)
# Get logits for the last token only
logits = logits[:, -1, :]
# Apply temperature
if temperature != 1.0:
logits = logits / temperature
# Greedy decoding (argmax) if sample=False
if not sample:
next_token = torch.argmax(logits, dim=-1, keepdim=True)
else:
# Sampling logic
# Apply top-k filtering
if top_k is not None:
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
logits = logits.masked_fill(indices_to_remove, float('-inf'))
# Apply top-p (nucleus) filtering
if top_p is not None:
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
# Remove tokens with cumulative probability above the threshold
sorted_indices_to_remove = cumulative_probs > top_p
# Shift the indices to the right to keep the first token above the threshold
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
indices_to_remove = sorted_indices_to_remove.scatter(
dim=1, index=sorted_indices, src=sorted_indices_to_remove
)
logits = logits.masked_fill(indices_to_remove, float('-inf'))
# Convert to probability distribution and sample
probs = F.softmax(logits, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
# Ensure next_token is on the same device before concatenation
next_token = next_token.to(device)
# Append the generated token to the sequence
generated = torch.cat((generated, next_token), dim=1)
return generated
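
    # Usage sketch (illustrative): the prompt tensor and sampling settings below
    # are arbitrary examples, not recommendations.
    #   prompt = torch.tensor([[1, 2, 3]])
    #   out = model.generate(prompt, max_new_tokens=50, temperature=0.8, top_p=0.9)  # nucleus sampling
    #   out = model.generate(prompt, max_new_tokens=50, top_k=40)                    # top-k sampling
    #   out = model.generate(prompt, max_new_tokens=50, sample=False)                # greedy decoding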
# Register the model with Hugging Face's Auto classes
AutoConfig.register("argonne", ArgonneConfig)
AutoModel.register(ArgonneConfig, ArgonneModel)
AutoModelForCausalLM.register(ArgonneConfig, ArgonneModel)
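
# Minimal smoke test (illustrative only): the tiny hyperparameters below are
# assumptions chosen so the script runs quickly on CPU; they are not the
# Argonne-1.5 release settings.
if __name__ == "__main__":
    config = ArgonneConfig(vocab_size=256, block_size=64, n_layer=2, n_head=2, n_embd=64)
    model = ArgonneModel(config)
    dummy_ids = torch.randint(0, config.vocab_size, (1, 8))
    logits, loss = model(dummy_ids, targets=dummy_ids)
    print(f"logits: {tuple(logits.shape)}, loss: {loss.item():.4f}")
    sample = model.generate(dummy_ids, max_new_tokens=4, top_k=10)
    print(f"generated sequence length: {sample.shape[1]}")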