Upload ChessBot Chess model
Browse files- modeling_chessbot.py +0 -71
modeling_chessbot.py
CHANGED
@@ -527,77 +527,6 @@ class ChessBotModel(ChessBotPreTrainedModel):
|
|
527 |
# Initialize weights
|
528 |
self.post_init()
|
529 |
|
530 |
-
@classmethod
def from_pretrained(cls, model_path, **kwargs):
    """
    Load a pretrained model from a directory (HuggingFace compatible)

    The directory is searched for ``pytorch_model.bin`` and then
    ``model.safetensors``; the first one found is loaded. If the directory
    has a ``config.json`` it is read through ``ChessBotConfig.from_pretrained``,
    otherwise a default ``ChessBotConfig()`` is used.

    Args:
        model_path: Directory containing the checkpoint files.
        **kwargs: Accepted so existing call sites keep working; this loader
            does not read any of them.

    Returns:
        A new ``cls(config)`` instance with the checkpoint tensors (read
        onto CPU) loaded into it.

    Raises:
        FileNotFoundError: If neither weights file exists in ``model_path``.
        ImportError: If the weights are in ``.safetensors`` format and the
            ``safetensors`` package is not installed.
    """
    import os
    import warnings

    # Load config
    config_path = os.path.join(model_path, "config.json")
    if os.path.exists(config_path):
        config = ChessBotConfig.from_pretrained(model_path)
    else:
        config = ChessBotConfig()

    # Create model instance
    model = cls(config)

    # Locate the weights file (first match wins, same order as before)
    candidates = (
        os.path.join(model_path, filename)
        for filename in ("pytorch_model.bin", "model.safetensors")
    )
    model_file = next((path for path in candidates if os.path.exists(path)), None)
    if model_file is None:
        raise FileNotFoundError(f"No model file found in {model_path}")

    if model_file.endswith(".safetensors"):
        # Only the import is guarded, so a genuine read error is not
        # mistaken for a missing dependency.
        try:
            from safetensors.torch import load_file
        except ImportError as err:
            raise ImportError("safetensors library is required to load .safetensors files. Install with: pip install safetensors") from err
        state_dict = load_file(model_file, device="cpu")
    else:
        # SECURITY: .bin checkpoints are pickle files. weights_only=True
        # restricts unpickling to tensors/containers so a downloaded
        # checkpoint cannot execute arbitrary code on load. State dicts
        # written by save_pretrained() contain only tensors, so they load
        # unchanged.
        state_dict = torch.load(model_file, map_location="cpu", weights_only=True)

    # strict=False is kept so partially matching checkpoints still load,
    # but the mismatch is surfaced instead of being silently discarded.
    load_result = model.load_state_dict(state_dict, strict=False)
    if load_result.missing_keys:
        warnings.warn(
            f"Weights missing from {model_file} (left at their initial values): "
            f"{list(load_result.missing_keys)}",
            stacklevel=2,
        )
    if load_result.unexpected_keys:
        warnings.warn(
            f"Weights in {model_file} not used by {cls.__name__}: "
            f"{list(load_result.unexpected_keys)}",
            stacklevel=2,
        )

    return model
|
576 |
-
|
577 |
-
def save_pretrained(self, save_directory, safe_serialization=False):
    """
    Save the model to a directory (HuggingFace compatible)

    Writes the config via ``self.config.save_pretrained`` and the weights
    from ``self.state_dict()`` into ``save_directory``, creating the
    directory if needed.

    Args:
        save_directory: Target directory for the config and weights.
        safe_serialization: If True, write ``model.safetensors``; when the
            ``safetensors`` package is not installed a warning is issued and
            ``pytorch_model.bin`` is written instead. If False, write
            ``pytorch_model.bin``.

    Returns:
        None.
    """
    import os
    import warnings

    os.makedirs(save_directory, exist_ok=True)

    # Save config
    self.config.save_pretrained(save_directory)

    # Save model weights
    state_dict = self.state_dict()

    if safe_serialization:
        # Guard only the import: an error raised while writing the file
        # must propagate rather than be confused with a missing package.
        try:
            from safetensors.torch import save_file
        except ImportError:
            warnings.warn(
                "⚠ Warning: safetensors not available, falling back to pytorch_model.bin",
                stacklevel=2,
            )
        else:
            # transformers' loader inspects the "format" metadata entry of a
            # safetensors checkpoint; tag the file so it is recognised as
            # PyTorch weights.
            save_file(
                state_dict,
                os.path.join(save_directory, "model.safetensors"),
                metadata={"format": "pt"},
            )
            return

    # Default path, and the fallback when safetensors is unavailable.
    torch.save(state_dict, os.path.join(save_directory, "pytorch_model.bin"))
|
600 |
-
|
601 |
def forward(self, input_ids, attention_mask=None, compute_loss=False):
|
602 |
"""
|
603 |
Forward pass compatible with both HuggingFace interface and original interface
|
|
|
527 |
# Initialize weights
|
528 |
self.post_init()
|
529 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
530 |
def forward(self, input_ids, attention_mask=None, compute_loss=False):
|
531 |
"""
|
532 |
Forward pass compatible with both HuggingFace interface and original interface
|