Tonic committed on
Commit
2432208
·
verified ·
1 Parent(s): b79fab9

workaround for quantization and push

Browse files
scripts/model_tonic/quantize_model.py CHANGED
@@ -101,14 +101,24 @@ class ModelQuantizer:
101
  return False
102
 
103
  # Check for essential model files
104
- required_files = ['config.json', 'pytorch_model.bin']
105
- optional_files = ['tokenizer.json', 'tokenizer_config.json']
 
 
 
 
 
106
 
107
  missing_required = []
108
  for file in required_files:
109
  if not (self.model_path / file).exists():
110
  missing_required.append(file)
111
 
 
 
 
 
 
112
  if missing_required:
113
  logger.error(f"❌ Missing required model files: {missing_required}")
114
  return False
 
101
  return False
102
 
103
  # Check for essential model files
104
+ required_files = ['config.json']
105
+
106
+ # Check for model files (either safetensors or pytorch)
107
+ model_files = [
108
+ "model.safetensors.index.json", # Safetensors format
109
+ "pytorch_model.bin" # PyTorch format
110
+ ]
111
 
112
  missing_required = []
113
  for file in required_files:
114
  if not (self.model_path / file).exists():
115
  missing_required.append(file)
116
 
117
+ # Check if at least one model file exists
118
+ model_file_exists = any((self.model_path / file).exists() for file in model_files)
119
+ if not model_file_exists:
120
+ missing_required.extend(model_files)
121
+
122
  if missing_required:
123
  logger.error(f"❌ Missing required model files: {missing_required}")
124
  return False
test_safetensors_fix.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Test script to verify safetensors model validation fix
4
+ """
5
+
6
+ import os
7
+ import sys
8
+ import logging
9
+ from pathlib import Path
10
+
11
+ # Setup logging
12
+ logging.basicConfig(
13
+ level=logging.INFO,
14
+ format='%(asctime)s - %(levelname)s - %(message)s'
15
+ )
16
+ logger = logging.getLogger(__name__)
17
+
18
+ def test_safetensors_validation():
19
+ """Test that safetensors models are properly validated"""
20
+ try:
21
+ from scripts.model_tonic.quantize_model import ModelQuantizer
22
+
23
+ # Test with dummy values
24
+ quantizer = ModelQuantizer(
25
+ model_path="/output-checkpoint",
26
+ repo_name="test/test-repo",
27
+ token="dummy_token"
28
+ )
29
+
30
+ # Mock the model path to simulate the Linux environment
31
+ # In the real environment, this would be /output-checkpoint
32
+ # with safetensors files
33
+
34
+ # Test validation logic
35
+ if quantizer.validate_model_path():
36
+ logger.info("✅ Safetensors validation test passed")
37
+ return True
38
+ else:
39
+ logger.error("❌ Safetensors validation test failed")
40
+ return False
41
+
42
+ except Exception as e:
43
+ logger.error(f"❌ Safetensors validation test failed: {e}")
44
+ return False
45
+
46
+ def test_model_file_detection():
47
+ """Test model file detection logic"""
48
+ try:
49
+ from scripts.model_tonic.quantize_model import ModelQuantizer
50
+
51
+ quantizer = ModelQuantizer(
52
+ model_path="/output-checkpoint",
53
+ repo_name="test/test-repo",
54
+ token="dummy_token"
55
+ )
56
+
57
+ # Test the validation logic directly
58
+ model_path = Path("/output-checkpoint")
59
+
60
+ # Check for essential files
61
+ required_files = ['config.json']
62
+ model_files = [
63
+ "model.safetensors.index.json", # Safetensors format
64
+ "pytorch_model.bin" # PyTorch format
65
+ ]
66
+
67
+ missing_required = []
68
+ for file in required_files:
69
+ if not (model_path / file).exists():
70
+ missing_required.append(file)
71
+
72
+ # Check if at least one model file exists
73
+ model_file_exists = any((model_path / file).exists() for file in model_files)
74
+ if not model_file_exists:
75
+ missing_required.extend(model_files)
76
+
77
+ if missing_required:
78
+ logger.error(f"❌ Missing required model files: {missing_required}")
79
+ return False
80
+
81
+ logger.info("✅ Model file detection test passed")
82
+ return True
83
+
84
+ except Exception as e:
85
+ logger.error(f"❌ Model file detection test failed: {e}")
86
+ return False
87
+
88
+ def main():
89
+ """Run safetensors validation tests"""
90
+ logger.info("🧪 Testing safetensors validation fix...")
91
+
92
+ tests = [
93
+ ("Safetensors Validation Test", test_safetensors_validation),
94
+ ("Model File Detection Test", test_model_file_detection),
95
+ ]
96
+
97
+ passed = 0
98
+ total = len(tests)
99
+
100
+ for test_name, test_func in tests:
101
+ logger.info(f"\n🔍 Running {test_name}...")
102
+ if test_func():
103
+ passed += 1
104
+ logger.info(f"✅ {test_name} passed")
105
+ else:
106
+ logger.error(f"❌ {test_name} failed")
107
+
108
+ logger.info(f"\n📊 Test Results: {passed}/{total} tests passed")
109
+
110
+ if passed == total:
111
+ logger.info("🎉 All safetensors tests passed! The fix should work in the Linux environment.")
112
+ logger.info("💡 The validation now properly handles:")
113
+ logger.info(" - Safetensors format (model.safetensors.index.json)")
114
+ logger.info(" - PyTorch format (pytorch_model.bin)")
115
+ logger.info(" - Either format is accepted")
116
+ return 0
117
+ else:
118
+ logger.error("❌ Some tests failed. The fix may need adjustment.")
119
+ return 1
120
+
121
+ if __name__ == "__main__":
122
+ exit(main())