File size: 15,627 Bytes
4395cf9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 |
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import (
PretrainedConfig,
PreTrainedModel,
AutoConfig,
AutoModel,
AutoModelForCausalLM
)
class ArgonneConfig(PretrainedConfig):
    """Hyper-parameters for the Argonne decoder-only transformer.

    Args:
        vocab_size: Number of token ids covered by the embedding table and
            the output head.
        block_size: Maximum sequence length; sizes the learned position
            table and the causal mask.
        n_layer: Number of transformer ``Block`` layers.
        n_head: Number of attention heads (``n_embd`` must be divisible by it).
        n_embd: Hidden / embedding width.
        dropout: Dropout probability used throughout the model.
        use_flash_attn: Prefer ``F.scaled_dot_product_attention`` when the
            installed PyTorch provides it.
        **kwargs: Forwarded unchanged to ``PretrainedConfig``.
    """

    model_type = "argonne"

    def __init__(
        self,
        vocab_size=12000,
        block_size=2048,
        n_layer=24,
        n_head=24,
        n_embd=1296,
        dropout=0.1,
        use_flash_attn=True,
        **kwargs,
    ):
        # Let the base class consume generic HF options first; the
        # architecture fields below are assigned afterwards.
        super().__init__(**kwargs)

        # Model geometry.
        self.vocab_size = vocab_size
        self.block_size = block_size
        self.n_layer = n_layer
        self.n_head = n_head
        self.n_embd = n_embd

        # Regularisation and attention-kernel selection.
        self.dropout = dropout
        self.use_flash_attn = use_flash_attn
class Block(nn.Module):
    """One pre-norm transformer layer: LayerNorm -> attention and LayerNorm -> MLP, each wrapped in a residual add."""

    def __init__(self, config):
        super().__init__()
        # Registration order (ln1, attn, ln2, mlp) determines state-dict and
        # parameters() ordering, so it is kept stable.
        self.ln1 = nn.LayerNorm(config.n_embd)
        self.attn = CausalSelfAttention(config)
        self.ln2 = nn.LayerNorm(config.n_embd)
        self.mlp = MLP(config)

    def forward(self, x):
        """Apply the attention sub-layer, then the MLP sub-layer, each with a residual connection."""
        attended = self.attn(self.ln1(x))
        x = x + attended
        transformed = self.mlp(self.ln2(x))
        return x + transformed
class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention with separate Q/K/V projections.

    Uses ``F.scaled_dot_product_attention`` (with ``is_causal=True``) when the
    installed PyTorch has it and ``config.use_flash_attn`` is truthy;
    otherwise falls back to an explicit softmax(QK^T / sqrt(d)) V computation
    that masks future positions with the registered ``mask`` buffer.
    """

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0, "Embedding dim must be divisible by n_head"
        self.n_head = config.n_head
        self.head_dim = config.n_embd // config.n_head

        # Submodules are registered in the same order as before so that
        # state-dict keys and parameters() ordering are unchanged.
        self.query = nn.Linear(config.n_embd, config.n_embd)
        self.key = nn.Linear(config.n_embd, config.n_embd)
        self.value = nn.Linear(config.n_embd, config.n_embd)
        self.attn_drop = nn.Dropout(config.dropout)
        self.resid_drop = nn.Dropout(config.dropout)
        self.proj = nn.Linear(config.n_embd, config.n_embd)
        self.use_flash_attn = getattr(config, 'use_flash_attn', True)

        # Lower-triangular (1 = visible) mask for the fallback path, shaped
        # (1, 1, block_size, block_size) so it broadcasts over batch and heads.
        causal = torch.tril(torch.ones(config.block_size, config.block_size))
        self.register_buffer("mask", causal.view(1, 1, config.block_size, config.block_size))

    def forward(self, x):
        """Attend over ``x`` of shape (batch, seq, n_embd) and return a tensor of the same shape."""
        batch, seq_len, width = x.size()

        def split_heads(projected):
            # (batch, seq, n_embd) -> (batch, n_head, seq, head_dim)
            return projected.view(batch, seq_len, self.n_head, self.head_dim).transpose(1, 2)

        q = split_heads(self.query(x))
        k = split_heads(self.key(x))
        v = split_heads(self.value(x))

        if self.use_flash_attn and hasattr(F, 'scaled_dot_product_attention'):
            # The fused kernel builds the causal mask itself, so no attn_mask
            # is passed; dropout is applied only while training.
            drop_p = self.attn_drop.p if self.training else 0.0
            heads = F.scaled_dot_product_attention(q, k, v, dropout_p=drop_p, is_causal=True)
        else:
            scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
            visible = self.mask[:, :, :seq_len, :seq_len]
            scores = scores.masked_fill(visible == 0, float('-inf'))
            weights = self.attn_drop(torch.softmax(scores, dim=-1))
            heads = weights @ v

        # (batch, n_head, seq, head_dim) -> (batch, seq, n_embd), then project.
        merged = heads.transpose(1, 2).contiguous().view(batch, seq_len, width)
        return self.resid_drop(self.proj(merged))
class MLP(nn.Module):
    """Position-wise feed-forward network: Linear(n_embd -> 4*n_embd) -> GELU -> Linear(back to n_embd), with dropout."""

    def __init__(self, config):
        super().__init__()
        hidden_dim = 4 * config.n_embd
        # Attribute names/order are part of the state-dict layout; keep them.
        self.fc1 = nn.Linear(config.n_embd, hidden_dim)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_dim, config.n_embd)
        # A single dropout module is reused after the activation and after
        # the output projection.
        self.drop = nn.Dropout(config.dropout)

    def forward(self, x):
        """Expand, activate, and project ``x`` back to its original width."""
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
class ArgonneModel(PreTrainedModel):
    """Decoder-only transformer language model for the ``argonne`` architecture.

    Token embeddings plus a learned absolute position table feed a stack of
    pre-norm ``Block`` layers, followed by a final LayerNorm (``ln_f``) and a
    linear output head (``head``). The model can run on a single device, be
    dispatched with ``accelerate`` (``device_map="auto"``), or be split into a
    sequential pipeline across GPUs with :meth:`distribute_model`.
    """

    config_class = ArgonneConfig

    def __init__(self, config, device_map=None):
        super().__init__(config)
        # Create embeddings on CPU initially
        self.token_embedding = nn.Embedding(config.vocab_size, config.n_embd)
        self.position_embedding = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd))
        self.drop = nn.Dropout(config.dropout)
        # Build all blocks
        self.blocks = nn.ModuleList([Block(config) for _ in range(config.n_layer)])
        # Final LayerNorm + output head
        self.ln_f = nn.LayerNorm(config.n_embd)
        self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        # _init_weights only handles nn.Linear / nn.Embedding modules, so the
        # bare position-embedding Parameter is initialised explicitly.
        nn.init.normal_(self.position_embedding, mean=0.0, std=0.02)
        self.post_init()
        # For pipeline parallelism (populated by distribute_model).
        self.pipeline_stages = None
        self.devices = []
        # Handle device_map="auto" for inference
        if device_map is not None:
            self.setup_device_map(device_map)

    def setup_device_map(self, device_map):
        """Place the model on devices according to ``device_map``.

        Only ``"auto"`` is implemented: ``accelerate`` infers a device map and
        dispatches the model across the available devices. If ``accelerate``
        is not installed a message is printed and the model stays where it is.
        Any other value is not supported; a notice is printed and the model is
        left unchanged.
        """
        if device_map != "auto":
            # Custom mappings are not implemented; report that instead of
            # silently ignoring the caller's request.
            print(f"Custom device_map values are not supported; ignoring device_map={device_map!r}.")
            return
        try:
            from accelerate import dispatch_model
            from accelerate.utils import infer_auto_device_map
        except ImportError:
            print("The 'accelerate' library is required for device_map='auto'. Please install it with 'pip install accelerate'.")
            print("Continuing with model on CPU or default device.")
            return
        # Get device map automatically, then dispatch model across devices.
        auto_device_map = infer_auto_device_map(self)
        dispatch_model(self, device_map=auto_device_map)
        print(f"Model automatically distributed across devices with device_map: {auto_device_map}")

    def distribute_model(self, device_ids=None):
        """Distribute the transformer blocks across devices in a pipeline style.

        Blocks are split into contiguous groups of ``ceil(n_layer / n_devices)``
        and each group is wrapped in an ``nn.Sequential`` stage on its device.
        Embeddings go to the first device; ``ln_f`` and ``head`` go to the last.

        Args:
            device_ids: Devices to use (anything ``torch.device`` accepts). If
                None, all visible CUDA devices are used.

        Raises:
            ValueError: If no GPUs are found (when ``device_ids`` is None) or
                ``device_ids`` is empty.
        """
        if device_ids is None:
            num_gpus = torch.cuda.device_count()
            if num_gpus < 1:
                raise ValueError("No GPUs found—can't do pipeline parallel on CPU only.")
            device_ids = [f"cuda:{i}" for i in range(num_gpus)]
        elif not device_ids:
            # Guard against a ZeroDivisionError in the split computation below.
            raise ValueError("device_ids must contain at least one device.")

        # Store them so the training loop can keep referencing model.devices
        self.devices = [torch.device(d) for d in device_ids]
        self.pipeline_stages = nn.ModuleList()

        num_gpus = len(device_ids)
        blocks_per_gpu = math.ceil(len(self.blocks) / num_gpus)
        start_idx = 0
        for i in range(num_gpus):
            end_idx = min(start_idx + blocks_per_gpu, len(self.blocks))
            stage_blocks = self.blocks[start_idx:end_idx]
            stage = nn.Sequential(*stage_blocks).to(device_ids[i])
            self.pipeline_stages.append(stage)
            start_idx = end_idx
            # With ceil-based splitting, trailing devices may receive no blocks.
            if end_idx >= len(self.blocks):
                break

        # Move embeddings to the first device
        first_device = device_ids[0]
        self.token_embedding = self.token_embedding.to(first_device)
        # For nn.Parameter, we need to move the data, not replace the parameter
        self.position_embedding.data = self.position_embedding.data.to(first_device)
        self.drop = self.drop.to(first_device)

        # Move final LayerNorm + head to the last device
        last_device = device_ids[-1]
        self.ln_f = self.ln_f.to(last_device)
        self.head = self.head.to(last_device)

        print(f"Model distributed across {len(device_ids)} devices")
        print(f"First device: {first_device}, Last device: {last_device}")
        print(f"Transformer layers per device: ~{blocks_per_gpu}")

    def _init_weights(self, module):
        """Initialise Linear/Embedding weights with N(0, 0.02) and Linear biases with zeros."""
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def prepare_for_compile(self):
        """
        Prepare model for torch.compile() by ensuring all components
        are compatible with the compiler.
        """
        # No special handling is currently needed; return self unchanged.
        return self

    def _embed_input(self, idx):
        """Return dropout(token + position embeddings) for 2-D ``idx``, on the embedding device."""
        idx = idx.to(self.token_embedding.weight.device)
        _, t = idx.size()
        assert t <= self.config.block_size, "Sequence length exceeds block size"
        token_embeddings = self.token_embedding(idx)
        position_embeddings = self.position_embedding[:, :t, :]
        return self.drop(token_embeddings + position_embeddings)

    def _lm_head_and_loss(self, hidden_states, targets):
        """Apply ``ln_f`` + ``head`` and, if ``targets`` is given, the cross-entropy loss."""
        hidden_states = self.ln_f(hidden_states)
        logits = self.head(hidden_states)
        loss = None
        if targets is not None:
            # Follow the logits' device: with accelerate dispatch or pipeline
            # stages the head is not necessarily on the embedding device.
            targets = targets.to(logits.device)
            logits = logits.view(-1, logits.size(-1))
            # reshape, not view: targets are often a non-contiguous slice
            # (e.g. batch[:, 1:]), which view(-1) cannot handle.
            loss = F.cross_entropy(logits, targets.reshape(-1))
        return logits, loss

    def forward(self, idx, targets=None):
        """Run the language model.

        If ``self.pipeline_stages`` is None, a normal forward pass runs on
        whatever device(s) the modules currently occupy; otherwise the hidden
        states are passed stage by stage across the pipeline devices.

        Args:
            idx: Token ids, (batch, seq) or a single unbatched (seq,) sequence.
            targets: Optional token ids aligned with ``idx`` for the loss.

        Returns:
            ``(logits, loss)``. ``logits`` has shape (batch, seq, vocab) when
            ``targets`` is None and is flattened to (batch * seq, vocab) when a
            loss is computed; ``loss`` is None without ``targets``.
        """
        if idx.dim() == 1:
            # Add batch dimension if missing
            idx = idx.unsqueeze(0)

        hidden_states = self._embed_input(idx)

        if self.pipeline_stages is None:
            # Single-device forward pass
            for block in self.blocks:
                hidden_states = block(hidden_states)
        else:
            # Pipeline parallel forward: hop to each stage's device in turn.
            for stage in self.pipeline_stages:
                stage_device = next(stage.parameters()).device
                hidden_states = stage(hidden_states.to(stage_device))
            # Explicitly move to the device of the final LayerNorm/head.
            hidden_states = hidden_states.to(next(self.ln_f.parameters()).device)

        return self._lm_head_and_loss(hidden_states, targets)

    @staticmethod
    def _pick_next_token(logits, temperature, top_k, top_p, sample):
        """Choose one token id per row from last-position ``logits`` (batch, vocab); returns (batch, 1)."""
        if not sample or temperature == 0:
            # Dividing by a positive temperature never changes the argmax, so
            # greedy decoding ignores it; temperature 0 also means greedy
            # (dividing by it would produce inf/NaN logits).
            return torch.argmax(logits, dim=-1, keepdim=True)

        if temperature != 1.0:
            logits = logits / temperature

        # Apply top-k filtering (k clamped to the vocabulary; <= 0 disables it).
        if top_k is not None and top_k > 0:
            k = min(top_k, logits.size(-1))
            kth_best = torch.topk(logits, k)[0][..., -1, None]
            logits = logits.masked_fill(logits < kth_best, float('-inf'))

        # Apply top-p (nucleus) filtering
        if top_p is not None:
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
            # Remove tokens with cumulative probability above the threshold
            sorted_to_remove = cumulative_probs > top_p
            # Shift right so the first token that crosses the threshold is kept
            sorted_to_remove[..., 1:] = sorted_to_remove[..., :-1].clone()
            sorted_to_remove[..., 0] = False
            # Map the removal mask from sorted order back to vocabulary order.
            to_remove = sorted_to_remove.scatter(dim=1, index=sorted_indices, src=sorted_to_remove)
            logits = logits.masked_fill(to_remove, float('-inf'))

        probs = F.softmax(logits, dim=-1)
        return torch.multinomial(probs, num_samples=1)

    @torch.no_grad()
    def generate(self, input_ids, max_new_tokens, temperature=0.7, top_k=None, top_p=None, sample=True):
        """Generate text by autoregressively extending ``input_ids``.

        Args:
            input_ids: (batch, seq) token ids to continue from.
            max_new_tokens: Number of tokens to generate.
            temperature: Sampling temperature (higher = more random). Ignored
                by greedy decoding; ``0`` selects greedy decoding.
            top_k: If set (> 0), only sample from the top k most likely tokens
                (clamped to the vocabulary size).
            top_p: If set, sample from the smallest set of tokens whose
                cumulative probability exceeds p.
            sample: If True, sample from the distribution; if False, use greedy
                decoding.

        Returns:
            Tensor containing ``input_ids`` (on the generation device) extended
            with ``max_new_tokens`` generated tokens. Only the last
            ``config.block_size`` tokens are fed to the model at each step, but
            the returned tensor keeps the full sequence.

        The model is switched to eval mode while generating and its previous
        train/eval mode is restored afterwards.
        """
        was_training = self.training
        self.eval()
        try:
            # Use the first pipeline device when distributed, for consistency.
            if self.pipeline_stages is not None and len(self.devices) > 0:
                device = self.devices[0]
            else:
                device = next(self.parameters()).device

            generated = input_ids.to(device)
            block_size = self.config.block_size
            for _ in range(max_new_tokens):
                # Crop only the model input to the context window; the full
                # history stays in `generated`.
                context = generated[:, -block_size:]
                logits, _ = self.forward(context)
                # Keep only the last position before moving across devices.
                last_logits = logits[:, -1, :].to(device)
                next_token = self._pick_next_token(last_logits, temperature, top_k, top_p, sample)
                generated = torch.cat((generated, next_token.to(device)), dim=1)
            return generated
        finally:
            self.train(was_training)
# Make the "argonne" model type resolvable through Hugging Face's Auto*
# factories: the config class first, then the same model class for both the
# bare-model and causal-LM entry points.
AutoConfig.register("argonne", ArgonneConfig)
for _auto_model_cls in (AutoModel, AutoModelForCausalLM):
    _auto_model_cls.register(ArgonneConfig, ArgonneModel)
del _auto_model_cls
|