Youzhi Yu committed
Commit ea46d13 · 1 Parent(s): 5324bdd

Fix generate method to handle CausalLMOutput, plus other updates

Files changed (1):
  1. model.py  +48 -98
model.py CHANGED
@@ -9,6 +9,7 @@ from transformers import (
     AutoModel,
     AutoModelForCausalLM
 )
+from transformers.modeling_outputs import CausalLMOutput
 
 from typing import Optional
 
@@ -102,6 +103,9 @@ class MLP(nn.Module):
 class ArgonneModel(PreTrainedModel):
     config_class = ArgonneConfig
 
+    # for device_map = "auto"
+    _no_split_modules = ["Block"]
+
     def __init__(self, config, device_map=None):
         super().__init__(config)
         # Create embeddings on CPU initially
@@ -214,18 +218,40 @@ class ArgonneModel(PreTrainedModel):
         # For now, we'll just return self since our model structure should be compatible
         return self
 
-    def forward(self, idx, targets=None):
+    def forward(
+        self,
+        input_ids=None,
+        attention_mask=None,
+        labels=None,
+        **kwargs
+    ):
         """
-        If self.pipeline_stages is None, we do a normal single-device forward
-        (whatever device everything is currently on—CPU or a single GPU).
-        Otherwise, we do a pipeline parallel forward.
+        HF-friendly forward method.
+
+        Args:
+            input_ids (torch.LongTensor): Tokens to be fed to the model. [batch_size, seq_len].
+            attention_mask (torch.LongTensor, optional): Mask of shape [batch_size, seq_len],
+                with 1 for actual tokens and 0 for padding, if you want to incorporate it.
+                Currently ignored in this minimal example.
+            labels (torch.LongTensor, optional): Targets for language modeling, same shape as `input_ids`.
+            **kwargs: Catch-all for any additional arguments (e.g. past_key_values) so we don't crash.
         """
-        # Make the forward method more compiler-friendly
+        # 1) We'll rename the parameters from the old code
+        if input_ids is None:
+            raise ValueError("`input_ids` must be provided.")
+
+        # We used to call it 'idx'
+        idx = input_ids
+        # We used to call it 'targets'
+        targets = labels
+
+        # [Optional] If we want to handle single-dim input_ids
         if idx.dim() == 1:
-            # Add batch dimension if missing
            idx = idx.unsqueeze(0)
-
-        # Rest of the forward method remains the same
+
+        # 2) Now the rest of your old forward logic remains, just replacing references
+        #    to "idx" and "targets" with these new variables.
+
         if self.pipeline_stages is None:
             # Single-device forward pass
             device = self.token_embedding.weight.device
@@ -250,7 +276,11 @@ class ArgonneModel(PreTrainedModel):
                 targets = targets.view(-1)
                 loss = F.cross_entropy(logits, targets)
 
-            return logits, loss
+            return CausalLMOutput(
+                loss=loss,
+                logits=logits,
+            )
+
         else:
             # Pipeline parallel forward
             first_device = next(self.token_embedding.parameters()).device
@@ -270,7 +300,7 @@ class ArgonneModel(PreTrainedModel):
                 hidden_states = hidden_states.to(device_stage)
                 hidden_states = stage(hidden_states)
 
-            # Explicitly move to last device before final operations
+            # Move to last device before final ops
             hidden_states = hidden_states.to(last_device)
             hidden_states = self.ln_f(hidden_states)
             logits = self.head(hidden_states)
@@ -282,7 +312,11 @@ class ArgonneModel(PreTrainedModel):
                 targets = targets.view(-1)
                 loss = F.cross_entropy(logits, targets)
 
-            return logits, loss
+            return CausalLMOutput(
+                loss=loss,
+                logits=logits,
+            )
+
 
 
     @torch.no_grad()
@@ -342,8 +376,9 @@ class ArgonneModel(PreTrainedModel):
                 generated = generated[:, -self.config.block_size:]
 
             # Forward pass
-            logits, _ = self.forward(generated)
-            logits = logits[:, -1, :] # get the last token's logits
+            outputs = self.forward(generated)
+            logits = outputs.logits # outputs is a CausalLMOutput
+            logits = logits[:, -1, :] # get the last token's logits
 
             # Temperature
             if temperature != 1.0:
@@ -382,91 +417,6 @@ class ArgonneModel(PreTrainedModel):
 
         return generated
 
-
-    # @torch.no_grad()
-    # def generate(self, input_ids, max_new_tokens, temperature=0.7, top_k=None, top_p=None, sample=True):
-    #     """
-    #     Generate text using the model.
-
-    #     Args:
-    #         input_ids: Input token IDs to continue from
-    #         max_new_tokens: Number of tokens to generate
-    #         temperature: Temperature for sampling (higher = more random)
-    #         top_k: If set, only sample from the top k most likely tokens
-    #         top_p: If set, sample from the smallest set of tokens whose cumulative probability exceeds p
-    #         sample: If True, sample from the distribution; if False, use greedy decoding
-
-    #     Returns:
-    #         Tensor containing the input_ids extended with max_new_tokens generated tokens
-    #     """
-    #     self.eval()
-
-    #     # Determine which device to use - explicitly use first device for consistency
-    #     if self.pipeline_stages is not None and len(self.devices) > 0:
-    #         device = self.devices[0] # Always use first device for generation
-    #     else:
-    #         device = next(self.parameters()).device
-
-    #     # Ensure input is on the correct device
-    #     generated = input_ids.to(device)
-
-    #     for _ in range(max_new_tokens):
-    #         # Truncate if necessary to fit within the model's context window
-    #         if generated.shape[1] > self.config.block_size:
-    #             generated = generated[:, -self.config.block_size:]
-
-    #         # Forward pass
-    #         logits, _ = self.forward(generated)
-
-    #         # Make sure logits are on the same device
-    #         logits = logits.to(device)
-
-    #         # Get logits for the last token only
-    #         logits = logits[:, -1, :]
-
-    #         # Apply temperature
-    #         if temperature != 1.0:
-    #             logits = logits / temperature
-
-    #         # Greedy decoding (argmax) if sample=False
-    #         if not sample:
-    #             next_token = torch.argmax(logits, dim=-1, keepdim=True)
-    #         else:
-    #             # Sampling logic
-    #             # Apply top-k filtering
-    #             if top_k is not None:
-    #                 indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
-    #                 logits = logits.masked_fill(indices_to_remove, float('-inf'))
-
-    #             # Apply top-p (nucleus) filtering
-    #             if top_p is not None:
-    #                 sorted_logits, sorted_indices = torch.sort(logits, descending=True)
-    #                 cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
-
-    #                 # Remove tokens with cumulative probability above the threshold
-    #                 sorted_indices_to_remove = cumulative_probs > top_p
-
-    #                 # Shift the indices to the right to keep the first token above the threshold
-    #                 sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
-    #                 sorted_indices_to_remove[..., 0] = 0
-
-    #                 indices_to_remove = sorted_indices_to_remove.scatter(
-    #                     dim=1, index=sorted_indices, src=sorted_indices_to_remove
-    #                 )
-    #                 logits = logits.masked_fill(indices_to_remove, float('-inf'))
-
-    #             # Convert to probability distribution and sample
-    #             probs = F.softmax(logits, dim=-1)
-    #             next_token = torch.multinomial(probs, num_samples=1)
-
-    #         # Ensure next_token is on the same device before concatenation
-    #         next_token = next_token.to(device)
-
-    #         # Append the generated token to the sequence
-    #         generated = torch.cat((generated, next_token), dim=1)
-
-    #     return generated
-
 # Register the model with Hugging Face's Auto classes
 AutoConfig.register("argonne", ArgonneConfig)
 AutoModel.register(ArgonneConfig, ArgonneModel)
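
For context, a minimal usage sketch of the API after this commit (not part of the diff). It assumes this repo's model.py is importable, that ArgonneConfig provides workable defaults such as vocab_size, and that the active generate keeps the (input_ids, max_new_tokens, temperature, top_k, top_p, sample) signature documented in the removed commented-out version:

import torch
from model import ArgonneConfig, ArgonneModel # assumption: this repo's model.py is on the path

config = ArgonneConfig() # assumption: default config values suffice for a smoke test
model = ArgonneModel(config)
model.eval()

input_ids = torch.randint(0, config.vocab_size, (1, 8))

# forward() now returns a transformers CausalLMOutput instead of a (logits, loss) tuple,
# so callers read the .loss and .logits attributes rather than unpacking.
outputs = model(input_ids=input_ids, labels=input_ids)
print(outputs.loss, outputs.logits.shape)

# generate() internally reads outputs.logits now, as shown in the diff above.
generated = model.generate(input_ids, max_new_tokens=20, temperature=0.7, top_k=50)
print(generated.shape)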