Maxlegrec commited on
Commit
9b6327b
·
verified ·
1 Parent(s): 4d0cb92

Upload ChessBot Chess model

Browse files
Files changed (1) hide show
  1. modeling_chessbot.py +0 -71
modeling_chessbot.py CHANGED
@@ -527,77 +527,6 @@ class ChessBotModel(ChessBotPreTrainedModel):
527
  # Initialize weights
528
  self.post_init()
529
 
530
- @classmethod
531
- def from_pretrained(cls, model_path, **kwargs):
532
- """
533
- Load a pretrained model from a directory (HuggingFace compatible)
534
- """
535
- import os
536
-
537
- # Load config
538
- config_path = os.path.join(model_path, "config.json")
539
- if os.path.exists(config_path):
540
- config = ChessBotConfig.from_pretrained(model_path)
541
- else:
542
- config = ChessBotConfig()
543
-
544
- # Create model instance
545
- model = cls(config)
546
-
547
- # Load weights
548
- model_file = None
549
- for filename in ["pytorch_model.bin", "model.safetensors"]:
550
- full_path = os.path.join(model_path, filename)
551
- if os.path.exists(full_path):
552
- model_file = full_path
553
- break
554
-
555
- if model_file is None:
556
- raise FileNotFoundError(f"No model file found in {model_path}")
557
-
558
- if model_file.endswith('.safetensors'):
559
- # Handle safetensors format
560
- try:
561
- from safetensors import safe_open
562
- state_dict = {}
563
- with safe_open(model_file, framework="pt", device="cpu") as f:
564
- for key in f.keys():
565
- state_dict[key] = f.get_tensor(key)
566
- except ImportError:
567
- raise ImportError("safetensors library is required to load .safetensors files. Install with: pip install safetensors")
568
- else:
569
- # Handle pytorch format
570
- state_dict = torch.load(model_file, map_location="cpu")
571
-
572
- # Load state dict into model
573
- model.load_state_dict(state_dict, strict=False)
574
-
575
- return model
576
-
577
- def save_pretrained(self, save_directory, safe_serialization=False):
578
- """
579
- Save the model to a directory (HuggingFace compatible)
580
- """
581
- import os
582
- os.makedirs(save_directory, exist_ok=True)
583
-
584
- # Save config
585
- self.config.save_pretrained(save_directory)
586
-
587
- # Save model weights
588
- if safe_serialization:
589
- try:
590
- from safetensors.torch import save_file
591
- model_path = os.path.join(save_directory, "model.safetensors")
592
- save_file(self.state_dict(), model_path)
593
- except ImportError:
594
- print("⚠ Warning: safetensors not available, falling back to pytorch_model.bin")
595
- model_path = os.path.join(save_directory, "pytorch_model.bin")
596
- torch.save(self.state_dict(), model_path)
597
- else:
598
- model_path = os.path.join(save_directory, "pytorch_model.bin")
599
- torch.save(self.state_dict(), model_path)
600
-
601
  def forward(self, input_ids, attention_mask=None, compute_loss=False):
602
  """
603
  Forward pass compatible with both HuggingFace interface and original interface
 
527
  # Initialize weights
528
  self.post_init()
529
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
530
  def forward(self, input_ids, attention_mask=None, compute_loss=False):
531
  """
532
  Forward pass compatible with both HuggingFace interface and original interface