anemll committed on
Commit 86baea7 · verified · 1 Parent(s): f056d10

Fixed GIL issue


Fixes a race condition between CoreML inference and the causal_mask update.
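In short, the fix builds the causal mask once (initialize_causal_mask) and passes that same tensor into run_prefill and generate_next_token, so the token loop only slices a read-only mask instead of rebuilding it while a CoreML predict call may still be running. Below is a minimal sketch of that pattern; the helper body is reconstructed from the mask code visible in this diff, and the CoreML predict calls are omitted.

import numpy as np
import torch

def make_causal_mask(length, start):
    # 0 where attention is allowed, -inf elsewhere (as in chat.py).
    mask = np.full((1, 1, length, length), -np.inf, dtype=np.float16)
    row_indices = np.arange(length).reshape(length, 1)
    col_indices = np.arange(length).reshape(1, length)
    mask[:, :, col_indices <= (row_indices + start)] = 0
    return mask

def initialize_causal_mask(context_length):
    """Build the shared causal mask once, before any generation starts."""
    return torch.tensor(make_causal_mask(context_length, 0), dtype=torch.float16)

context_length = 512
causal_mask = initialize_causal_mask(context_length)

# Prefill slices a batch-sized window out of the shared mask ...
batch_pos, batch_size = 0, 64
batch_causal_mask = causal_mask[:, :, batch_pos:batch_pos + batch_size, :]

# ... and per-token decode slices a single row. Neither path mutates the
# mask, so background CoreML execution never races against a mask rebuild.
pos = 10
single_causal_mask = causal_mask[:, :, pos - 1:pos, :]
print(batch_causal_mask.shape, single_causal_mask.shape)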

Files changed (1)
  1. chat.py +323 -204
chat.py CHANGED
@@ -26,8 +26,10 @@ DARK_BLUE = "\033[34m"
26
  LIGHT_GREEN = "\033[92m"
27
  RESET_COLOR = "\033[0m"
28
 
29
- # Add at top with other constants
30
  WARMUP_TOKEN_LIMIT = 10 # Maximum tokens to generate during warmup
 
 
31
 
32
  class TokenPrinter:
33
  """Handles background printing of generated tokens."""
@@ -40,9 +42,12 @@ class TokenPrinter:
40
  self.lock = threading.Lock()
41
  self.thinking = True # Track if we're still in thinking mode
42
  self.decoding_buffer = [] # Buffer for token IDs
43
- # Add token counting and timing
44
  self.start_time = time.time()
45
  self.token_count = 0
 
 
 
46
  self.start()
47
 
48
  def start(self):
@@ -103,15 +108,15 @@ class TokenPrinter:
103
  self.thread.join(timeout=1.0)
104
  except Exception:
105
  pass
106
- # Calculate and print tokens/s with shorter format in blue
107
- elapsed = time.time() - self.start_time
108
- if elapsed > 0 and self.token_count > 0:
109
- tokens_per_sec = self.token_count / elapsed
110
- print(f"\n{DARK_BLUE}{tokens_per_sec:.1f} t/s{RESET_COLOR}")
111
- else:
112
- print(RESET_COLOR) # Reset color at the end
113
  return self.buffer
114
115
  def parse_model_path(path):
116
  """Parse model path and return full path with .mlmodelc or .mlpackage extension."""
117
  path = Path(path)
@@ -188,6 +193,89 @@ def load_model(path, function_name=None):
188
  print("\nTry using the .mlpackage version instead, or recompile the model.")
189
  raise
190
191
  def load_metadata(model,args):
192
  # Extract metadata and config parameters
193
  metadata = {}
@@ -386,102 +474,99 @@ def make_causal_mask(length, start):
386
  mask[:, :, col_indices <= (row_indices + start)] = 0
387
  return mask
388
 
389
- def run_prefill(embed_model, ffn_models, input_ids, context_pos, context_length, batch_size=64, state=None):
390
  """Run prefill on the input sequence."""
391
- # Create causal mask
392
- causal_mask = make_causal_mask(context_length, 0)
393
- causal_mask = torch.tensor(causal_mask, dtype=torch.float16)
394
 
395
  # Process in batches
396
  batch_pos = 0
397
- while batch_pos < context_pos:
398
- batch_end = min(batch_pos + batch_size, context_pos)
399
  current_batch_size = batch_end - batch_pos
400
 
 
 
401
  # Get current batch
402
  batch_input = input_ids[:, batch_pos:batch_end]
403
 
404
- # Always pad to full batch size for prefill
405
  batch_input = F.pad(
406
  batch_input,
407
  (0, batch_size - current_batch_size),
408
  value=0
409
  )
410
 
411
- # Generate position IDs for full batch size
412
- position_ids = torch.arange(batch_size, dtype=torch.int32) # Changed: Always use full batch size
413
- batch_causal_mask = causal_mask[:, :, :batch_size, :] # Changed: Use full batch size
 
 
414
 
415
  # Run embeddings
416
  hidden_states = torch.from_numpy(
417
  embed_model.predict({'input_ids': batch_input.numpy()})['hidden_states']
418
  )
419
 
420
- # Run through FFN chunks with state
421
  for ffn_model in ffn_models:
422
  if isinstance(ffn_model, dict):
423
  inputs = {
424
- 'hidden_states': hidden_states.numpy(), # [1, 64, hidden_size]
425
- 'position_ids': position_ids.numpy(), # [64]
426
- 'causal_mask': batch_causal_mask.numpy(), # [1, 1, 64, context_length]
427
- 'current_pos': np.array([batch_pos], dtype=np.int32) # [1]
428
  }
429
  output = ffn_model['prefill'].predict(inputs, state)
430
  hidden_states = torch.from_numpy(output['output_hidden_states'])
431
 
432
  batch_pos = batch_end
433
 
434
- return torch.tensor([context_pos], dtype=torch.int32)
435
 
436
- def generate_next_token(embed_model, ffn_models, lmhead_model, input_ids, pos, context_length, state=None, temperature=0.0):
437
  """Generate the next token."""
438
  # Get current token
439
- current_token = input_ids[:, pos-1:pos] # [1, 1]
440
 
441
  # Run embeddings
442
  hidden_states = torch.from_numpy(
443
  embed_model.predict({'input_ids': current_token.numpy()})['hidden_states']
444
- ) # [1, 1, hidden_size]
445
 
446
  # Create masks
447
  update_mask = torch.zeros((1, 1, context_length, 1), dtype=torch.float16)
448
  update_mask[0, 0, pos-1, 0] = 1.0
449
- position_ids = torch.tensor([pos-1], dtype=torch.int32) # [1]
450
- causal_mask = make_causal_mask(context_length, 0)
451
- causal_mask = torch.tensor(causal_mask[:, :, pos-1:pos, :], dtype=torch.float16) # [1, 1, 1, context_length]
452
 
453
- # Run through FFN chunks with state
 
 
 
454
  for ffn_model in ffn_models:
455
  if isinstance(ffn_model, dict):
456
  inputs = {
457
  'hidden_states': hidden_states.numpy(),
458
  'update_mask': update_mask.numpy(),
459
  'position_ids': position_ids.numpy(),
460
- 'causal_mask': causal_mask.numpy(),
461
  'current_pos': position_ids.numpy()
462
  }
463
  output = ffn_model['infer'].predict(inputs, state)
464
  hidden_states = torch.from_numpy(output['output_hidden_states'])
465
 
466
- # Run LM head
467
  lm_output = lmhead_model.predict({'hidden_states': hidden_states.numpy()})
468
- # Debug print
469
- #print("\nLM Head output keys:", list(lm_output.keys()))
470
 
471
- # Combine logits1-8 if they exist
472
  if 'logits1' in lm_output:
473
- # Concatenate all logits parts
474
  logits_parts = []
475
  for i in range(1, 9):
476
  key = f'logits{i}'
477
  if key in lm_output:
478
  logits_parts.append(torch.from_numpy(lm_output[key]))
479
- logits = torch.cat(logits_parts, dim=-1) # Concatenate along vocab dimension
480
  else:
481
- # Try output_logits as fallback
482
  logits = torch.from_numpy(lm_output['output_logits'])
483
 
484
- # Apply temperature and sample
485
  if temperature > 0:
486
  logits = logits / temperature
487
  probs = F.softmax(logits[0, -1, :], dim=-1)
@@ -503,36 +588,93 @@ def create_unified_state(ffn_models, context_length):
503
  print("\nCreated unified transformer state")
504
  return state
505
 
506
- def chat_loop(embed_model, ffn_models, lmhead_model, tokenizer, metadata, state, auto_prompt=None, warmup=False):
507
  """Interactive chat loop."""
 
508
  context_length = metadata.get('context_length')
509
  batch_size = metadata.get('batch_size', 64)
510
 
511
  if not warmup:
512
  print(f"\nUsing context length: {context_length}")
513
  print("\nStarting chat session. Press Ctrl+D to exit.")
514
- print("Type your message and press Enter to chat.")
515
-
516
- # Check if tokenizer has chat template and if it works
517
- has_chat_template = False
518
- try:
519
- # Test if chat template works
520
- test_messages = [{"role": "user", "content": "test"}]
521
- tokenizer.apply_chat_template(test_messages, return_tensors="pt")
522
- has_chat_template = True
523
- if not warmup:
524
- print("\nUsing chat template for prompts")
525
- except:
526
- if not warmup:
527
- print("\nUsing manual formatting for prompts")
528
 
 
529
  conversation = []
530
 
531
  try:
532
  while True:
533
  try:
534
  if not warmup:
535
- print(f"\n{LIGHT_GREEN}You:{RESET_COLOR}", end=' ', flush=True)
536
  if auto_prompt is not None:
537
  user_input = auto_prompt
538
  if not warmup:
@@ -543,41 +685,69 @@ def chat_loop(embed_model, ffn_models, lmhead_model, tokenizer, metadata, state,
543
  if not warmup:
544
  print("\nExiting chat...")
545
  break
546
-
547
  if not user_input:
548
  continue
549
 
550
- # Format prompt based on tokenizer capabilities
551
- if has_chat_template:
552
- messages = [{"role": "user", "content": user_input}]
553
- input_ids = tokenizer.apply_chat_template(
554
- messages,
555
  return_tensors="pt",
556
  add_generation_prompt=True
557
  ).to(torch.int32)
558
  else:
559
- # Manual formatting for Llama models without chat template
560
- formatted_prompt = f"[INST] {user_input} [/INST]"
561
- input_ids = tokenizer(
562
- formatted_prompt,
563
  return_tensors="pt",
564
- add_special_tokens=True
565
- ).input_ids.to(torch.int32)
566
 
567
- context_pos = input_ids.size(1)
568
 
569
  if not warmup:
570
  print(f"\n{LIGHT_BLUE}Assistant:{RESET_COLOR}", end=' ', flush=True)
571
 
572
- # Initialize token printer
573
  token_printer = TokenPrinter(tokenizer)
574
- tokens_generated = 0 # Track number of tokens
 
575
 
576
  try:
577
- # Start prefill timing
578
- prefill_start = time.time()
579
-
580
- # Run prefill with state
581
  current_pos = run_prefill(
582
  embed_model,
583
  ffn_models,
@@ -585,21 +755,53 @@ def chat_loop(embed_model, ffn_models, lmhead_model, tokenizer, metadata, state,
585
  context_pos,
586
  context_length,
587
  batch_size,
588
- state
 
589
  )
 
590
 
591
- # Calculate prefill timing
592
- prefill_time = time.time() - prefill_start
593
- prefill_tokens = context_pos # Number of tokens in input
594
- prefill_tokens_per_sec = prefill_tokens / prefill_time if prefill_time > 0 else 0
595
-
596
- # Generation loop with state
597
- input_ids = input_ids
598
  pos = context_pos
599
- inference_start = time.time()
600
- inference_tokens = 0
601
 
602
- while pos < context_length - 1:
603
  # Generate next token
604
  next_token = generate_next_token(
605
  embed_model,
@@ -608,146 +810,58 @@ def chat_loop(embed_model, ffn_models, lmhead_model, tokenizer, metadata, state,
608
  input_ids,
609
  pos,
610
  context_length,
611
- state
 
612
  )
613
 
614
- # Add token to sequence
615
- if pos < input_ids.size(1):
616
- input_ids[0, pos] = next_token
617
- else:
618
- input_ids = torch.cat([
619
- input_ids,
620
- torch.tensor([[next_token]], dtype=torch.int32)
621
- ], dim=1)
622
-
623
- # Add to printer only if not in warmup
624
  if not warmup:
625
  token_printer.add_token(next_token)
626
  token_printer.drain_buffer()
 
627
 
628
  pos += 1
629
  tokens_generated += 1
630
- inference_tokens += 1
631
 
632
- # Check limits
633
  if warmup and tokens_generated >= WARMUP_TOKEN_LIMIT:
634
  break
635
-
636
  if next_token == tokenizer.eos_token_id:
637
  break
638
 
639
- # Calculate inference timing
640
- inference_time = time.time() - inference_start
641
- inference_tokens_per_sec = inference_tokens / inference_time if inference_time > 0 else 0
 
 
642
 
643
- # Get final response and add to conversation
644
  if not warmup:
645
- response = token_printer.stop()
646
- # Print timing stats
647
- prefill_ms = prefill_time * 1000 # Convert to milliseconds
648
- print(f"\nPrefill: {prefill_ms:.1f}ms ({prefill_tokens_per_sec:.1f} t/s)")
649
- print(f"Inference: {inference_tokens_per_sec:.1f} t/s")
650
- print(f"Total: Generated {tokens_generated} tokens in {prefill_time + inference_time:.2f}s")
651
- conversation.append({"role": "assistant", "content": response})
652
- else:
653
- token_printer.stop() # Clean up without printing stats
654
 
655
- # Exit after one response in auto_prompt mode
656
  if auto_prompt is not None:
657
  break
658
 
659
  except KeyboardInterrupt:
660
- print("\nGeneration interrupted")
 
661
  token_printer.stop()
662
  continue
663
 
664
  except Exception as e:
665
- print(f"\nError in chat loop: {str(e)}")
666
- import traceback
667
- traceback.print_exc()
668
-
669
- def parse_args():
670
- parser = argparse.ArgumentParser(description='Chat with CoreML LLaMA (c) 2025 Anemll')
671
-
672
- # Add meta.yaml option
673
- parser.add_argument('--meta', type=str, help='Path to meta.yaml to load all parameters')
674
-
675
- # Model paths
676
- parser.add_argument('--d', '--dir', type=str, default='.',
677
- help='Directory containing model files (default: current directory)')
678
- parser.add_argument('--embed', type=str, required=False,
679
- help='Path to embeddings model (relative to --dir)')
680
- parser.add_argument('--ffn', type=str, required=False,
681
- help='Path to FFN model (can be chunked, relative to --dir)')
682
- parser.add_argument('--lmhead', type=str, required=False,
683
- help='Path to LM head model (relative to --dir)')
684
- parser.add_argument('--tokenizer', type=str, required=False,
685
- help='Path to tokenizer')
686
-
687
- # Add new argument for auto-generation
688
- parser.add_argument('--prompt', type=str,
689
- help='If specified, run once with this prompt and exit')
690
-
691
- # Add no-warmup flag
692
- parser.add_argument('--nw', action='store_true',
693
- help='Skip warmup phase')
694
-
695
- # Model configuration
696
- parser.add_argument('--context-length', type=int,
697
- help='Context length for the model (default: 512), if not provided, it will be detected from the model directory name ctxNUMBER')
698
- parser.add_argument('--batch-size', type=int,
699
- help='Batch size for prefill (default: 64)')
700
-
701
- args = parser.parse_args()
702
-
703
- # If meta.yaml is provided, load parameters from it
704
- if args.meta:
705
- try:
706
- with open(args.meta, 'r') as f:
707
- meta = yaml.safe_load(f)
708
- params = meta['model_info']['parameters']
709
-
710
- # Set model directory to meta.yaml directory if not specified
711
- if not args.d or args.d == '.':
712
- args.d = str(Path(args.meta).parent)
713
-
714
- # Build model paths based on parameters
715
- prefix = params.get('model_prefix', 'llama') # Default to 'llama' if not specified
716
- lut_ffn = f"_lut{params['lut_ffn']}" if params['lut_ffn'] != 'none' else ''
717
- lut_lmhead = f"_lut{params['lut_lmhead']}" if params['lut_lmhead'] != 'none' else ''
718
- num_chunks = int(params['num_chunks'])
719
-
720
- # Set model paths if not specified
721
- if not args.embed:
722
- args.embed = f'{prefix}_embeddings'
723
- if not args.lmhead:
724
- args.lmhead = f'{prefix}_lm_head{lut_lmhead}'
725
- if not args.ffn:
726
- args.ffn = f'{prefix}_FFN_PF{lut_ffn}_chunk_01of{num_chunks:02d}'
727
- if not args.tokenizer:
728
- args.tokenizer = args.d
729
-
730
- # Set other parameters if not overridden by command line
731
- if args.context_length is None:
732
- args.context_length = int(params['context_length'])
733
- if args.batch_size is None:
734
- args.batch_size = int(params['batch_size'])
735
- args.num_chunks = num_chunks
736
-
737
- print(f"\nLoaded parameters from {args.meta}:")
738
- print(f" Context Length: {args.context_length}")
739
- print(f" Batch Size: {args.batch_size}")
740
- print(f" Num Chunks: {args.num_chunks}")
741
- print(f" Models Directory: {args.d}")
742
- print(f" Embeddings: {args.embed}")
743
- print(f" LM Head: {args.lmhead}")
744
- print(f" FFN: {args.ffn}")
745
-
746
- except Exception as e:
747
- print(f"\nError loading meta.yaml: {str(e)}")
748
- sys.exit(1)
749
-
750
- return args
751
 
752
  def main():
753
  args = parse_args()
@@ -800,6 +914,9 @@ def main():
800
  # Create unified state once
801
  state = create_unified_state(ffn_models, metadata['context_length'])
802
803
  # Warmup runs to prevent Python GIL issues with CoreML !
804
  if not args.nw:
805
  for i in range(2):
@@ -809,7 +926,8 @@ def main():
809
  lmhead_model=lmhead_model,
810
  tokenizer=tokenizer,
811
  metadata=metadata,
812
- state=state,
 
813
  warmup=True,
814
  auto_prompt="who are you?"
815
  )
@@ -821,7 +939,8 @@ def main():
821
  lmhead_model=lmhead_model,
822
  tokenizer=tokenizer,
823
  metadata=metadata,
824
- state=state,
 
825
  warmup=False,
826
  auto_prompt=args.prompt
827
  )
 
26
  LIGHT_GREEN = "\033[92m"
27
  RESET_COLOR = "\033[0m"
28
 
29
+ # Add at the top with other constants
30
  WARMUP_TOKEN_LIMIT = 10 # Maximum tokens to generate during warmup
31
+ THINKING_MODE = False
32
+ THINKING_PROMPT = """You are a deep thinking AI, you may use extremely long chains of thought to deeply consider the problem and deliberate with yourself via systematic reasoning processes to help come to a correct solution prior to answering. You should enclose your thoughts and internal monologue inside <think> </think> tags, and then provide your solution or response to the problem."""
33
 
34
  class TokenPrinter:
35
  """Handles background printing of generated tokens."""
 
42
  self.lock = threading.Lock()
43
  self.thinking = True # Track if we're still in thinking mode
44
  self.decoding_buffer = [] # Buffer for token IDs
45
+ # Timing and stats tracking
46
  self.start_time = time.time()
47
  self.token_count = 0
48
+ self.prefill_time = 0
49
+ self.inference_time = 0
50
+ self.context_pos = 0
51
  self.start()
52
 
53
  def start(self):
 
108
  self.thread.join(timeout=1.0)
109
  except Exception:
110
  pass
111
+ print(RESET_COLOR) # Reset color at the end
112
  return self.buffer
113
 
114
+ def set_timing(self, prefill_time, inference_time, context_pos):
115
+ """Set timing information."""
116
+ self.prefill_time = prefill_time
117
+ self.inference_time = inference_time
118
+ self.context_pos = context_pos
119
+
120
  def parse_model_path(path):
121
  """Parse model path and return full path with .mlmodelc or .mlpackage extension."""
122
  path = Path(path)
 
193
  print("\nTry using the .mlpackage version instead, or recompile the model.")
194
  raise
195
 
196
+ def parse_args():
197
+ parser = argparse.ArgumentParser(description='Full Chat with CoreML LLaMA with context window shifting, gil resolved (c) 2025 Anemll')
198
+
199
+ # Add meta.yaml option
200
+ parser.add_argument('--meta', type=str, help='Path to meta.yaml to load all parameters')
201
+
202
+ # Add existing arguments
203
+ parser.add_argument('--d', '--dir', type=str, default='.',
204
+ help='Directory containing model files (default: current directory)')
205
+ parser.add_argument('--embed', type=str, required=False,
206
+ help='Path to embeddings model (relative to --dir)')
207
+ parser.add_argument('--ffn', type=str, required=False,
208
+ help='Path to FFN model (can be chunked, relative to --dir)')
209
+ parser.add_argument('--lmhead', type=str, required=False,
210
+ help='Path to LM head model (relative to --dir)')
211
+ parser.add_argument('--tokenizer', type=str, required=False,
212
+ help='Path to tokenizer')
213
+
214
+ # Add new argument for auto-generation
215
+ parser.add_argument('--prompt', type=str,
216
+ help='If specified, run once with this prompt and exit')
217
+
218
+ # Add no-warmup flag
219
+ parser.add_argument('--nw', action='store_true',
220
+ help='Skip warmup phase')
221
+
222
+ # Model configuration
223
+ parser.add_argument('--context-length', type=int,
224
+ help='Context length for the model (default: 512), if not provided, it will be detected from the model directory name ctxNUMBER')
225
+ parser.add_argument('--batch-size', type=int,
226
+ help='Batch size for prefill (default: 64)')
227
+
228
+ args = parser.parse_args()
229
+
230
+ # If meta.yaml is provided, load parameters from it
231
+ if args.meta:
232
+ try:
233
+ with open(args.meta, 'r') as f:
234
+ meta = yaml.safe_load(f)
235
+ params = meta['model_info']['parameters']
236
+
237
+ # Set model directory to meta.yaml directory if not specified
238
+ if not args.d or args.d == '.':
239
+ args.d = str(Path(args.meta).parent)
240
+
241
+ # Build model paths based on parameters
242
+ prefix = params.get('model_prefix', 'llama') # Default to 'llama' if not specified
243
+ lut_ffn = f"_lut{params['lut_ffn']}" if params['lut_ffn'] != 'none' else ''
244
+ lut_lmhead = f"_lut{params['lut_lmhead']}" if params['lut_lmhead'] != 'none' else ''
245
+ num_chunks = int(params['num_chunks'])
246
+
247
+ # Set model paths if not specified
248
+ if not args.embed:
249
+ args.embed = f'{prefix}_embeddings'
250
+ if not args.lmhead:
251
+ args.lmhead = f'{prefix}_lm_head{lut_lmhead}'
252
+ if not args.ffn:
253
+ args.ffn = f'{prefix}_FFN_PF{lut_ffn}_chunk_01of{num_chunks:02d}'
254
+ if not args.tokenizer:
255
+ args.tokenizer = args.d
256
+
257
+ # Set other parameters if not overridden by command line
258
+ if args.context_length is None:
259
+ args.context_length = int(params['context_length'])
260
+ if args.batch_size is None:
261
+ args.batch_size = int(params['batch_size'])
262
+ args.num_chunks = num_chunks
263
+
264
+ print(f"\nLoaded parameters from {args.meta}:")
265
+ print(f" Context Length: {args.context_length}")
266
+ print(f" Batch Size: {args.batch_size}")
267
+ print(f" Num Chunks: {args.num_chunks}")
268
+ print(f" Models Directory: {args.d}")
269
+ print(f" Embeddings: {args.embed}")
270
+ print(f" LM Head: {args.lmhead}")
271
+ print(f" FFN: {args.ffn}")
272
+
273
+ except Exception as e:
274
+ print(f"\nError loading meta.yaml: {str(e)}")
275
+ sys.exit(1)
276
+
277
+ return args
278
+
279
  def load_metadata(model,args):
280
  # Extract metadata and config parameters
281
  metadata = {}
 
474
  mask[:, :, col_indices <= (row_indices + start)] = 0
475
  return mask
476
 
477
+ def run_prefill(embed_model, ffn_models, input_ids, current_pos, context_length, batch_size, state, causal_mask):
478
  """Run prefill on the input sequence."""
479
+ #print(f"[DEBUG] Running prefill from 0 to {current_pos}")
 
 
480
 
481
  # Process in batches
482
  batch_pos = 0
483
+ while batch_pos < current_pos:
484
+ batch_end = min(batch_pos + batch_size, current_pos)
485
  current_batch_size = batch_end - batch_pos
486
 
487
+ #print(f"[DEBUG] Prefill batch {batch_pos}-{batch_end} (size={current_batch_size})")
488
+
489
  # Get current batch
490
  batch_input = input_ids[:, batch_pos:batch_end]
491
 
492
+ # Pad to full batch size
493
  batch_input = F.pad(
494
  batch_input,
495
  (0, batch_size - current_batch_size),
496
  value=0
497
  )
498
 
499
+ # Generate position IDs for this batch
500
+ position_ids = torch.arange(batch_pos, batch_pos + batch_size, dtype=torch.int32)
501
+
502
+ # Use the pre-initialized causal mask and extract the batch portion
503
+ batch_causal_mask = causal_mask[:, :, batch_pos:batch_pos + batch_size, :]
504
 
505
  # Run embeddings
506
  hidden_states = torch.from_numpy(
507
  embed_model.predict({'input_ids': batch_input.numpy()})['hidden_states']
508
  )
509
 
510
+ # Run through FFN chunks
511
  for ffn_model in ffn_models:
512
  if isinstance(ffn_model, dict):
513
  inputs = {
514
+ 'hidden_states': hidden_states.numpy(),
515
+ 'position_ids': position_ids.numpy(),
516
+ 'causal_mask': batch_causal_mask.numpy(),
517
+ 'current_pos': np.array([batch_pos], dtype=np.int32)
518
  }
519
  output = ffn_model['prefill'].predict(inputs, state)
520
  hidden_states = torch.from_numpy(output['output_hidden_states'])
521
 
522
  batch_pos = batch_end
523
 
524
+ return torch.tensor([current_pos], dtype=torch.int32)
525
 
526
+ def generate_next_token(embed_model, ffn_models, lmhead_model, input_ids, pos, context_length, state, causal_mask, temperature=0.0):
527
  """Generate the next token."""
528
  # Get current token
529
+ current_token = input_ids[:, pos-1:pos]
530
 
531
  # Run embeddings
532
  hidden_states = torch.from_numpy(
533
  embed_model.predict({'input_ids': current_token.numpy()})['hidden_states']
534
+ )
535
 
536
  # Create masks
537
  update_mask = torch.zeros((1, 1, context_length, 1), dtype=torch.float16)
538
  update_mask[0, 0, pos-1, 0] = 1.0
539
+ position_ids = torch.tensor([pos-1], dtype=torch.int32)
 
 
540
 
541
+ # Use the pre-initialized causal mask and extract the single position portion
542
+ single_causal_mask = causal_mask[:, :, pos-1:pos, :]
543
+
544
+ # Run through FFN chunks
545
  for ffn_model in ffn_models:
546
  if isinstance(ffn_model, dict):
547
  inputs = {
548
  'hidden_states': hidden_states.numpy(),
549
  'update_mask': update_mask.numpy(),
550
  'position_ids': position_ids.numpy(),
551
+ 'causal_mask': single_causal_mask.numpy(),
552
  'current_pos': position_ids.numpy()
553
  }
554
  output = ffn_model['infer'].predict(inputs, state)
555
  hidden_states = torch.from_numpy(output['output_hidden_states'])
556
 
557
+ # Run LM head and get next token
558
  lm_output = lmhead_model.predict({'hidden_states': hidden_states.numpy()})
 
 
559
 
 
560
  if 'logits1' in lm_output:
 
561
  logits_parts = []
562
  for i in range(1, 9):
563
  key = f'logits{i}'
564
  if key in lm_output:
565
  logits_parts.append(torch.from_numpy(lm_output[key]))
566
+ logits = torch.cat(logits_parts, dim=-1)
567
  else:
 
568
  logits = torch.from_numpy(lm_output['output_logits'])
569
 
 
570
  if temperature > 0:
571
  logits = logits / temperature
572
  probs = F.softmax(logits[0, -1, :], dim=-1)
 
588
  print("\nCreated unified transformer state")
589
  return state
590
 
591
+ def initialize_causal_mask(context_length):
592
+ """Initialize causal mask for transformer attention."""
593
+ causal_mask = make_causal_mask(context_length, 0)
594
+ causal_mask = torch.tensor(causal_mask, dtype=torch.float16)
595
+ print(f"\nInitialized causal mask for context length {context_length}")
596
+ return causal_mask
597
+
598
+ def get_user_input():
599
+ """Get input from user, handling special key combinations."""
600
+ global THINKING_MODE
601
+ try:
602
+ import termios
603
+ import tty
604
+ import sys
605
+
606
+ def _getch():
607
+ fd = sys.stdin.fileno()
608
+ old_settings = termios.tcgetattr(fd)
609
+ try:
610
+ tty.setraw(sys.stdin.fileno())
611
+ ch = sys.stdin.read(1)
612
+ finally:
613
+ termios.tcsetattr(fd, termios.TCSADRAIN, old_settings)
614
+ return ch
615
+
616
+ buffer = []
617
+ while True:
618
+ char = _getch()
619
+
620
+ # Debug: print the character code
621
+ print(f"\nKey pressed: {repr(char)} (hex: {hex(ord(char))})")
622
+
623
+ # Check for Enter key
624
+ if char == '\r' or char == '\n':
625
+ print() # Move to next line
626
+ input_text = ''.join(buffer)
627
+ # Check if the command is /t
628
+ if input_text == '/t':
629
+ THINKING_MODE = not THINKING_MODE
630
+ print(f"Thinking mode {'ON' if THINKING_MODE else 'OFF'}")
631
+ buffer = [] # Clear buffer
632
+ print(f"\n{LIGHT_GREEN}You{' (thinking)' if THINKING_MODE else ''}:{RESET_COLOR}", end=' ', flush=True)
633
+ continue
634
+ return input_text
635
+
636
+ # Handle backspace
637
+ if char == '\x7f': # backspace
638
+ if buffer:
639
+ buffer.pop()
640
+ sys.stdout.write('\b \b') # Erase character
641
+ sys.stdout.flush()
642
+ continue
643
+
644
+ # Handle Ctrl-C
645
+ if char == '\x03': # Ctrl-C
646
+ print("^C")
647
+ raise KeyboardInterrupt
648
+
649
+ # Print character and add to buffer
650
+ sys.stdout.write(char)
651
+ sys.stdout.flush()
652
+ buffer.append(char)
653
+
654
+ except ImportError:
655
+ # Fallback for systems without termios
656
+ return input("> ")
657
+
658
+ def chat_loop(embed_model, ffn_models, lmhead_model, tokenizer, metadata, state, causal_mask, auto_prompt=None, warmup=False):
659
  """Interactive chat loop."""
660
+ global THINKING_MODE
661
  context_length = metadata.get('context_length')
662
  batch_size = metadata.get('batch_size', 64)
663
 
664
  if not warmup:
665
  print(f"\nUsing context length: {context_length}")
666
  print("\nStarting chat session. Press Ctrl+D to exit.")
667
+ print("Type your message and press Enter to chat. Use /t to toggle thinking mode.")
668
+ print(f"Thinking mode is {'ON' if THINKING_MODE else 'OFF'}")
 
 
 
 
 
 
 
 
 
 
 
 
669
 
670
+ # Keep track of conversation history
671
  conversation = []
672
 
673
  try:
674
  while True:
675
  try:
676
  if not warmup:
677
+ print(f"\n{LIGHT_GREEN}You{' (thinking)' if THINKING_MODE else ''}:{RESET_COLOR}", end=' ', flush=True)
678
  if auto_prompt is not None:
679
  user_input = auto_prompt
680
  if not warmup:
 
685
  if not warmup:
686
  print("\nExiting chat...")
687
  break
688
+
689
  if not user_input:
690
  continue
691
+
692
+ # Handle /t command
693
+ if user_input == "/t":
694
+ THINKING_MODE = not THINKING_MODE
695
+ print(f"Thinking mode {'ON' if THINKING_MODE else 'OFF'}")
696
+ continue
697
 
698
+ # Add user message to conversation
699
+ conversation.append({"role": "user", "content": user_input})
700
+
701
+ # Format using chat template with full history
702
+ if THINKING_MODE:
703
+ # Add thinking prompt to system message
704
+ conversation_with_thinking = [{"role": "system", "content": THINKING_PROMPT}] + conversation
705
+ base_input_ids = tokenizer.apply_chat_template(
706
+ conversation_with_thinking,
707
  return_tensors="pt",
708
  add_generation_prompt=True
709
  ).to(torch.int32)
710
  else:
711
+ base_input_ids = tokenizer.apply_chat_template(
712
+ conversation,
 
 
713
  return_tensors="pt",
714
+ add_generation_prompt=True
715
+ ).to(torch.int32)
716
+
717
+ # Check if we need to trim history
718
+ while base_input_ids.size(1) > context_length - 100: # Leave room for response
719
+ # Remove oldest message pair (user + assistant)
720
+ if len(conversation) > 2:
721
+ conversation = conversation[2:] # Remove oldest pair
722
+ base_input_ids = tokenizer.apply_chat_template(
723
+ conversation,
724
+ return_tensors="pt",
725
+ add_generation_prompt=True
726
+ ).to(torch.int32)
727
+ else:
728
+ # If only current message remains and still too long, truncate
729
+ base_input_ids = base_input_ids[:, -context_length//2:]
730
+ break
731
+
732
+ context_pos = base_input_ids.size(1)
733
 
734
+ # Pad sequence to context_size
735
+ input_ids = F.pad(
736
+ base_input_ids,
737
+ (0, context_length - context_pos),
738
+ value=0
739
+ )
740
 
741
  if not warmup:
742
  print(f"\n{LIGHT_BLUE}Assistant:{RESET_COLOR}", end=' ', flush=True)
743
 
744
+ # Initialize token printer and collect response
745
  token_printer = TokenPrinter(tokenizer)
746
+ response_tokens = []
747
+ generation_start_time = time.time()
748
 
749
  try:
750
+ # Run prefill on entire context
 
 
 
751
  current_pos = run_prefill(
752
  embed_model,
753
  ffn_models,
 
755
  context_pos,
756
  context_length,
757
  batch_size,
758
+ state,
759
+ causal_mask
760
  )
761
+ #print(f"\n[DEBUG] After initial prefill - current_pos: {current_pos}")
762
 
763
+ # Generation loop
764
  pos = context_pos
765
+ tokens_generated = 0
766
+ inference_start = time.time() # Start inference timing
767
 
768
+ while True:
769
+ # Check if we need to shift window
770
+ if pos >= context_length - 2:
771
+ # Calculate shift to maintain full batches
772
+ batch_size = metadata.get('batch_size', 64)
773
+ # Calculate max batches that fit in context
774
+ max_batches = context_length // batch_size
775
+ desired_batches = max(1, max_batches - 2) # Leave room for new tokens
776
+ new_size = min(desired_batches * batch_size, context_length - batch_size)
777
+
778
+ # Create shifted input_ids
779
+ tmp = torch.zeros((1, context_length), dtype=torch.int32)
780
+ tmp[:,0:new_size] = input_ids[:,pos-new_size:pos]
781
+ input_ids = tmp
782
+
783
+ # Reset state and run prefill
784
+ # keep the same state
785
+ #state = create_unified_state(ffn_models, context_length)
786
+ current_pos = run_prefill(
787
+ embed_model,
788
+ ffn_models,
789
+ input_ids,
790
+ new_size, # Prefill the entire shifted content
791
+ context_length,
792
+ batch_size,
793
+ state,
794
+ causal_mask
795
+ )
796
+
797
+ # Start generating from the next position
798
+ pos = new_size # Don't back up, continue from where we left off
799
+
800
+ #print(f"\n[DEBUG] After shift - next token will be at pos {pos}")
801
+ #print(f"[DEBUG] Context before next token: {tokenizer.decode(input_ids[0, pos-40:pos])}")
802
+
803
+ window_shifted = True
804
+
805
  # Generate next token
806
  next_token = generate_next_token(
807
  embed_model,
 
810
  input_ids,
811
  pos,
812
  context_length,
813
+ state,
814
+ causal_mask
815
  )
816
 
817
+ # Add token
818
+ input_ids[0, pos] = next_token
819
  if not warmup:
820
  token_printer.add_token(next_token)
821
  token_printer.drain_buffer()
822
+ response_tokens.append(next_token)
823
 
824
  pos += 1
825
  tokens_generated += 1
 
826
 
827
+ # In warmup mode, limit tokens
828
  if warmup and tokens_generated >= WARMUP_TOKEN_LIMIT:
829
  break
830
+
831
  if next_token == tokenizer.eos_token_id:
832
  break
833
 
834
+ inference_time = time.time() - inference_start # Calculate inference time
835
+
836
+ # Add assistant response to conversation
837
+ response_text = token_printer.stop()
838
+ conversation.append({"role": "assistant", "content": response_text})
839
 
840
+ # Print stats only if not in warmup
841
  if not warmup:
842
+ total_time = time.time() - generation_start_time
843
+ prefill_time = total_time - inference_time
844
+ inference_tokens_per_sec = len(response_tokens) / inference_time if inference_time > 0 else 0
845
+ prefill_ms = prefill_time * 1000
846
+ prefill_tokens_per_sec = context_pos / prefill_time if prefill_time > 0 else 0
847
+ print(f"{DARK_BLUE}{inference_tokens_per_sec:.1f} t/s, "
848
+ f"TTFT: {prefill_ms:.1f}ms ({prefill_tokens_per_sec:.1f} t/s), "
849
+ f"{len(response_tokens)} tokens{RESET_COLOR}")
 
850
 
 
851
  if auto_prompt is not None:
852
  break
853
 
854
  except KeyboardInterrupt:
855
+ if not warmup:
856
+ print("\nGeneration interrupted")
857
  token_printer.stop()
858
  continue
859
 
860
  except Exception as e:
861
+ if not warmup:
862
+ print(f"\nError in chat loop: {str(e)}")
863
+ import traceback
864
+ traceback.print_exc()
865
 
866
  def main():
867
  args = parse_args()
 
914
  # Create unified state once
915
  state = create_unified_state(ffn_models, metadata['context_length'])
916
 
917
+ # Initialize causal mask once
918
+ causal_mask = initialize_causal_mask(metadata['context_length'])
919
+
920
  # Warmup runs to prevent Python GIL issues with CoreML !
921
  if not args.nw:
922
  for i in range(2):
 
926
  lmhead_model=lmhead_model,
927
  tokenizer=tokenizer,
928
  metadata=metadata,
929
+ state=state, # Pass the state
930
+ causal_mask=causal_mask, # Pass the causal mask
931
  warmup=True,
932
  auto_prompt="who are you?"
933
  )
 
939
  lmhead_model=lmhead_model,
940
  tokenizer=tokenizer,
941
  metadata=metadata,
942
+ state=state, # Pass the state
943
+ causal_mask=causal_mask, # Pass the causal mask
944
  warmup=False,
945
  auto_prompt=args.prompt
946
  )