Youzhi Yu
committed
Commit: ea46d13
Parent(s): 5324bdd
Fix generate method to handle CausalLMOutput, plus other updates
model.py
CHANGED
@@ -9,6 +9,7 @@ from transformers import (
     AutoModel,
     AutoModelForCausalLM
 )
+from transformers.modeling_outputs import CausalLMOutput
 
 from typing import Optional
 
@@ -102,6 +103,9 @@ class MLP(nn.Module):
 class ArgonneModel(PreTrainedModel):
     config_class = ArgonneConfig
 
+    # for device_map = "auto"
+    _no_split_modules = ["Block"]
+
     def __init__(self, config, device_map=None):
         super().__init__(config)
         # Create embeddings on CPU initially
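As an aside, here is a minimal, hypothetical loading sketch of what the new `_no_split_modules = ["Block"]` attribute above enables; the repo id is a placeholder and `trust_remote_code=True` assumes the custom code is shipped alongside a published checkpoint:

```python
# Sketch only: _no_split_modules tells accelerate's device-map planner never to
# split a single "Block" across devices, so device_map="auto" can shard the model
# one transformer block at a time. The repo id below is a placeholder.
from transformers import AutoModel

model = AutoModel.from_pretrained(
    "your-username/argonne",   # placeholder repo id, not an actual checkpoint
    trust_remote_code=True,    # ArgonneModel/ArgonneConfig live in the repo's model.py
    device_map="auto",         # place whole Blocks across the available GPUs/CPU
)
```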
@@ -214,18 +218,40 @@ class ArgonneModel(PreTrainedModel):
         # For now, we'll just return self since our model structure should be compatible
         return self
 
-    def forward(self, idx, targets=None):
+    def forward(
+        self,
+        input_ids=None,
+        attention_mask=None,
+        labels=None,
+        **kwargs
+    ):
         """
-
-
-
+        HF-friendly forward method.
+
+        Args:
+            input_ids (torch.LongTensor): Tokens to be fed to the model. [batch_size, seq_len].
+            attention_mask (torch.LongTensor, optional): Mask of shape [batch_size, seq_len],
+                with 1 for actual tokens and 0 for padding, if you want to incorporate it.
+                Currently ignored in this minimal example.
+            labels (torch.LongTensor, optional): Targets for language modeling, same shape as `input_ids`.
+            **kwargs: Catch-all for any additional arguments (e.g. past_key_values) so we don't crash.
         """
-        #
+        # 1) We'll rename the parameters from the old code
+        if input_ids is None:
+            raise ValueError("`input_ids` must be provided.")
+
+        # We used to call it 'idx'
+        idx = input_ids
+        # We used to call it 'targets'
+        targets = labels
+
+        # [Optional] If we want to handle single-dim input_ids
         if idx.dim() == 1:
-            # Add batch dimension if missing
             idx = idx.unsqueeze(0)
-
-        #
+
+        # 2) Now the rest of your old forward logic remains, just replacing references
+        #    to "idx" and "targets" with these new variables.
+
         if self.pipeline_stages is None:
             # Single-device forward pass
             device = self.token_embedding.weight.device
@@ -250,7 +276,11 @@
                 targets = targets.view(-1)
                 loss = F.cross_entropy(logits, targets)
 
-            return logits, loss
+            return CausalLMOutput(
+                loss=loss,
+                logits=logits,
+            )
+
         else:
             # Pipeline parallel forward
             first_device = next(self.token_embedding.parameters()).device
@@ -270,7 +300,7 @@
                 hidden_states = hidden_states.to(device_stage)
                 hidden_states = stage(hidden_states)
 
-            #
+            # Move to last device before final ops
             hidden_states = hidden_states.to(last_device)
             hidden_states = self.ln_f(hidden_states)
             logits = self.head(hidden_states)
@@ -282,7 +312,11 @@
                 targets = targets.view(-1)
                 loss = F.cross_entropy(logits, targets)
 
-            return logits, loss
+            return CausalLMOutput(
+                loss=loss,
+                logits=logits,
+            )
+
 
 
     @torch.no_grad()
@@ -342,8 +376,9 @@
                 generated = generated[:, -self.config.block_size:]
 
             # Forward pass
-            logits, _ = self.forward(generated)
-            logits = logits[:, -1, :]
+            outputs = self.forward(generated)
+            logits = outputs.logits  # outputs is a CausalLMOutput
+            logits = logits[:, -1, :]  # get the last token's logits
 
             # Temperature
             if temperature != 1.0:
@@ -382,91 +417,6 @@
 
         return generated
 
-
-    # @torch.no_grad()
-    # def generate(self, input_ids, max_new_tokens, temperature=0.7, top_k=None, top_p=None, sample=True):
-    #     """
-    #     Generate text using the model.
-
-    #     Args:
-    #         input_ids: Input token IDs to continue from
-    #         max_new_tokens: Number of tokens to generate
-    #         temperature: Temperature for sampling (higher = more random)
-    #         top_k: If set, only sample from the top k most likely tokens
-    #         top_p: If set, sample from the smallest set of tokens whose cumulative probability exceeds p
-    #         sample: If True, sample from the distribution; if False, use greedy decoding
-
-    #     Returns:
-    #         Tensor containing the input_ids extended with max_new_tokens generated tokens
-    #     """
-    #     self.eval()
-
-    #     # Determine which device to use - explicitly use first device for consistency
-    #     if self.pipeline_stages is not None and len(self.devices) > 0:
-    #         device = self.devices[0]  # Always use first device for generation
-    #     else:
-    #         device = next(self.parameters()).device
-
-    #     # Ensure input is on the correct device
-    #     generated = input_ids.to(device)
-
-    #     for _ in range(max_new_tokens):
-    #         # Truncate if necessary to fit within the model's context window
-    #         if generated.shape[1] > self.config.block_size:
-    #             generated = generated[:, -self.config.block_size:]
-
-    #         # Forward pass
-    #         logits, _ = self.forward(generated)
-
-    #         # Make sure logits are on the same device
-    #         logits = logits.to(device)
-
-    #         # Get logits for the last token only
-    #         logits = logits[:, -1, :]
-
-    #         # Apply temperature
-    #         if temperature != 1.0:
-    #             logits = logits / temperature
-
-    #         # Greedy decoding (argmax) if sample=False
-    #         if not sample:
-    #             next_token = torch.argmax(logits, dim=-1, keepdim=True)
-    #         else:
-    #             # Sampling logic
-    #             # Apply top-k filtering
-    #             if top_k is not None:
-    #                 indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
-    #                 logits = logits.masked_fill(indices_to_remove, float('-inf'))
-
-    #             # Apply top-p (nucleus) filtering
-    #             if top_p is not None:
-    #                 sorted_logits, sorted_indices = torch.sort(logits, descending=True)
-    #                 cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
-
-    #                 # Remove tokens with cumulative probability above the threshold
-    #                 sorted_indices_to_remove = cumulative_probs > top_p
-
-    #                 # Shift the indices to the right to keep the first token above the threshold
-    #                 sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
-    #                 sorted_indices_to_remove[..., 0] = 0
-
-    #                 indices_to_remove = sorted_indices_to_remove.scatter(
-    #                     dim=1, index=sorted_indices, src=sorted_indices_to_remove
-    #                 )
-    #                 logits = logits.masked_fill(indices_to_remove, float('-inf'))
-
-    #             # Convert to probability distribution and sample
-    #             probs = F.softmax(logits, dim=-1)
-    #             next_token = torch.multinomial(probs, num_samples=1)
-
-    #         # Ensure next_token is on the same device before concatenation
-    #         next_token = next_token.to(device)
-
-    #         # Append the generated token to the sequence
-    #         generated = torch.cat((generated, next_token), dim=1)
-
-    #     return generated
-
 # Register the model with Hugging Face's Auto classes
 AutoConfig.register("argonne", ArgonneConfig)
 AutoModel.register(ArgonneConfig, ArgonneModel)
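For reference, a minimal sketch of how a caller interacts with the updated model after this commit; the checkpoint path is a placeholder, and the `generate()` keyword arguments are assumed to match the commented-out signature removed above rather than taken from the active code:

```python
# Sketch only: illustrates the new forward contract (a CausalLMOutput with .loss and
# .logits) and a generate() call that now reads outputs.logits internally.
import torch
from model import ArgonneModel  # the file changed in this commit

model = ArgonneModel.from_pretrained("path/to/local/checkpoint")  # placeholder path
model.eval()

input_ids = torch.tensor([[1, 2, 3, 4]])                 # toy token ids
outputs = model(input_ids=input_ids, labels=input_ids)   # HF-style keyword arguments
print(outputs.loss)           # cross-entropy loss, computed only when labels are given
print(outputs.logits.shape)   # raw language-model logits

# Autoregressive decoding; kwargs assumed from the removed commented-out signature.
generated = model.generate(input_ids, max_new_tokens=20, temperature=0.7)
```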
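And a small sketch of what the Auto-class registration at the bottom of model.py provides once the module has been imported; constructing ArgonneConfig purely from its defaults is an assumption:

```python
# Sketch only: AutoConfig.register("argonne", ArgonneConfig) and
# AutoModel.register(ArgonneConfig, ArgonneModel) run at import time, after which
# the generic Auto classes resolve the custom "argonne" model type.
from transformers import AutoConfig, AutoModel

import model  # noqa: F401  # importing model.py executes the register() calls

config = AutoConfig.for_model("argonne")  # an ArgonneConfig built from default values
argonne = AutoModel.from_config(config)   # an ArgonneModel instance
print(type(argonne).__name__)             # -> "ArgonneModel"
```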