Spaces:
Running
Running
| #!/usr/bin/env python3 | |
| """ | |
| Comprehensive TRL compatibility test | |
| Verifies all TRL interface requirements are met | |
| """ | |
| import sys | |
| import os | |
| sys.path.append(os.path.dirname(os.path.abspath(__file__))) | |
| def test_core_interface(): | |
| """Test core TRL interface requirements""" | |
| print("π§ͺ Testing Core TRL Interface...") | |
| try: | |
| import trackio | |
| # Test 1: Core functions exist | |
| required_functions = ['init', 'log', 'finish'] | |
| for func_name in required_functions: | |
| assert hasattr(trackio, func_name), f"trackio.{func_name} not found" | |
| print(f"β trackio.{func_name} exists") | |
| # Test 2: Config attribute exists | |
| assert hasattr(trackio, 'config'), "trackio.config not found" | |
| print("β trackio.config exists") | |
| # Test 3: Config has update method | |
| config = trackio.config | |
| assert hasattr(config, 'update'), "trackio.config.update not found" | |
| print("β trackio.config.update exists") | |
| return True | |
| except Exception as e: | |
| print(f"β Core interface test failed: {e}") | |
| return False | |
| def test_init_functionality(): | |
| """Test init function with various argument patterns""" | |
| print("\nπ§ Testing Init Functionality...") | |
| try: | |
| import trackio | |
| # Test 1: No arguments (TRL compatibility) | |
| try: | |
| experiment_id = trackio.init() | |
| print(f"β trackio.init() without args: {experiment_id}") | |
| except Exception as e: | |
| print(f"β trackio.init() without args failed: {e}") | |
| return False | |
| # Test 2: With arguments | |
| try: | |
| experiment_id = trackio.init(project_name="test_project", experiment_name="test_exp") | |
| print(f"β trackio.init() with args: {experiment_id}") | |
| except Exception as e: | |
| print(f"β trackio.init() with args failed: {e}") | |
| return False | |
| # Test 3: With kwargs | |
| try: | |
| experiment_id = trackio.init(test_param="test_value") | |
| print(f"β trackio.init() with kwargs: {experiment_id}") | |
| except Exception as e: | |
| print(f"β trackio.init() with kwargs failed: {e}") | |
| return False | |
| return True | |
| except Exception as e: | |
| print(f"β Init functionality test failed: {e}") | |
| return False | |
| def test_log_functionality(): | |
| """Test log function with various metric types""" | |
| print("\nπ Testing Log Functionality...") | |
| try: | |
| import trackio | |
| # Test 1: Basic metrics | |
| try: | |
| trackio.log({'loss': 0.5, 'accuracy': 0.8}) | |
| print("β trackio.log() with basic metrics") | |
| except Exception as e: | |
| print(f"β trackio.log() with basic metrics failed: {e}") | |
| return False | |
| # Test 2: With step parameter | |
| try: | |
| trackio.log({'loss': 0.4, 'lr': 1e-4}, step=100) | |
| print("β trackio.log() with step parameter") | |
| except Exception as e: | |
| print(f"β trackio.log() with step failed: {e}") | |
| return False | |
| # Test 3: TRL-specific metrics | |
| try: | |
| trackio.log({ | |
| 'total_tokens': 1000, | |
| 'truncated_tokens': 50, | |
| 'padding_tokens': 20, | |
| 'throughput': 100.5, | |
| 'step_time': 0.1 | |
| }) | |
| print("β trackio.log() with TRL-specific metrics") | |
| except Exception as e: | |
| print(f"β trackio.log() with TRL metrics failed: {e}") | |
| return False | |
| return True | |
| except Exception as e: | |
| print(f"β Log functionality test failed: {e}") | |
| return False | |
| def test_config_update(): | |
| """Test config update with TRL-specific patterns""" | |
| print("\nβοΈ Testing Config Update...") | |
| try: | |
| import trackio | |
| config = trackio.config | |
| # Test 1: TRL-specific keyword arguments | |
| try: | |
| config.update(allow_val_change=True, project_name="trl_test") | |
| print(f"β Config update with TRL kwargs: allow_val_change={config.allow_val_change}") | |
| except Exception as e: | |
| print(f"β Config update with TRL kwargs failed: {e}") | |
| return False | |
| # Test 2: Dictionary update | |
| try: | |
| config.update({'experiment_name': 'test_exp', 'new_param': 'value'}) | |
| print(f"β Config update with dict: experiment_name={config.experiment_name}") | |
| except Exception as e: | |
| print(f"β Config update with dict failed: {e}") | |
| return False | |
| # Test 3: Mixed update | |
| try: | |
| config.update({'mixed_param': 'dict_value'}, kwarg_param='keyword_value') | |
| print(f"β Config update with mixed args: mixed_param={config.mixed_param}, kwarg_param={config.kwarg_param}") | |
| except Exception as e: | |
| print(f"β Config update with mixed args failed: {e}") | |
| return False | |
| return True | |
| except Exception as e: | |
| print(f"β Config update test failed: {e}") | |
| return False | |
| def test_finish_functionality(): | |
| """Test finish function""" | |
| print("\nπ Testing Finish Functionality...") | |
| try: | |
| import trackio | |
| # Test finish function | |
| try: | |
| trackio.finish() | |
| print("β trackio.finish() completed successfully") | |
| except Exception as e: | |
| print(f"β trackio.finish() failed: {e}") | |
| return False | |
| return True | |
| except Exception as e: | |
| print(f"β Finish functionality test failed: {e}") | |
| return False | |
| def test_trl_trainer_simulation(): | |
| """Simulate TRL trainer usage patterns""" | |
| print("\nπ€ Testing TRL Trainer Simulation...") | |
| try: | |
| import trackio | |
| # Simulate SFTTrainer initialization | |
| try: | |
| # Initialize trackio (like TRL does) | |
| experiment_id = trackio.init() | |
| print(f"β TRL-style initialization: {experiment_id}") | |
| # Update config (like TRL does) | |
| trackio.config.update(allow_val_change=True, project_name="trl_simulation") | |
| print("β TRL-style config update") | |
| # Log metrics (like TRL does during training) | |
| for step in range(1, 4): | |
| trackio.log({ | |
| 'loss': 1.0 / step, | |
| 'learning_rate': 1e-4, | |
| 'total_tokens': step * 1000, | |
| 'throughput': 100.0 / step | |
| }, step=step) | |
| print(f"β TRL-style logging at step {step}") | |
| # Finish experiment (like TRL does) | |
| trackio.finish() | |
| print("β TRL-style finish") | |
| except Exception as e: | |
| print(f"β TRL trainer simulation failed: {e}") | |
| return False | |
| return True | |
| except Exception as e: | |
| print(f"β TRL trainer simulation test failed: {e}") | |
| return False | |
| def test_error_handling(): | |
| """Test error handling and fallbacks""" | |
| print("\nπ‘οΈ Testing Error Handling...") | |
| try: | |
| import trackio | |
| # Test 1: Graceful handling of missing monitor | |
| try: | |
| # This should not crash even if monitor is not available | |
| trackio.log({'test': 1.0}) | |
| print("β Graceful handling of logging without monitor") | |
| except Exception as e: | |
| print(f"β οΈ Logging without monitor: {e}") | |
| # This is acceptable - just a warning | |
| # Test 2: Config update with invalid data | |
| try: | |
| config = trackio.config | |
| config.update(invalid_param=None) | |
| print("β Config update with invalid data handled gracefully") | |
| except Exception as e: | |
| print(f"β Config update with invalid data failed: {e}") | |
| return False | |
| return True | |
| except Exception as e: | |
| print(f"β Error handling test failed: {e}") | |
| return False | |
| def test_dict_style_access(): | |
| """Test dictionary-style access to TrackioConfig""" | |
| print("\nπ Testing Dictionary-Style Access...") | |
| try: | |
| import trackio | |
| config = trackio.config | |
| # Test 1: Dictionary-style assignment | |
| try: | |
| config['test_key'] = 'test_value' | |
| print(f"β Dictionary assignment: test_key={config['test_key']}") | |
| except Exception as e: | |
| print(f"β Dictionary assignment failed: {e}") | |
| return False | |
| # Test 2: Dictionary-style access | |
| try: | |
| value = config['test_key'] | |
| print(f"β Dictionary access: {value}") | |
| except Exception as e: | |
| print(f"β Dictionary access failed: {e}") | |
| return False | |
| # Test 3: Contains check | |
| try: | |
| has_key = 'test_key' in config | |
| print(f"β Contains check: {'test_key' in config}") | |
| except Exception as e: | |
| print(f"β Contains check failed: {e}") | |
| return False | |
| # Test 4: Get method | |
| try: | |
| value = config.get('test_key', 'default') | |
| default_value = config.get('nonexistent', 'default') | |
| print(f"β Get method: {value}, default: {default_value}") | |
| except Exception as e: | |
| print(f"β Get method failed: {e}") | |
| return False | |
| # Test 5: TRL-style usage | |
| try: | |
| config['allow_val_change'] = True | |
| config['report_to'] = 'trackio' | |
| print(f"β TRL-style config: allow_val_change={config['allow_val_change']}, report_to={config['report_to']}") | |
| except Exception as e: | |
| print(f"β TRL-style config failed: {e}") | |
| return False | |
| return True | |
| except Exception as e: | |
| print(f"β Dictionary-style access test failed: {e}") | |
| return False | |
| def main(): | |
| """Run comprehensive TRL compatibility tests""" | |
| print("π§ͺ Comprehensive TRL Compatibility Test") | |
| print("=" * 50) | |
| tests = [ | |
| ("Core Interface", test_core_interface), | |
| ("Init Functionality", test_init_functionality), | |
| ("Log Functionality", test_log_functionality), | |
| ("Config Update", test_config_update), | |
| ("Finish Functionality", test_finish_functionality), | |
| ("TRL Trainer Simulation", test_trl_trainer_simulation), | |
| ("Error Handling", test_error_handling), | |
| ("Dictionary-Style Access", test_dict_style_access), | |
| ] | |
| results = [] | |
| for test_name, test_func in tests: | |
| print(f"\n{'='*20} {test_name} {'='*20}") | |
| try: | |
| result = test_func() | |
| results.append((test_name, result)) | |
| except Exception as e: | |
| print(f"β {test_name} crashed: {e}") | |
| results.append((test_name, False)) | |
| # Summary | |
| print("\n" + "=" * 50) | |
| print("π TRL Compatibility Test Results") | |
| print("=" * 50) | |
| passed = 0 | |
| total = len(results) | |
| for test_name, result in results: | |
| status = "β PASSED" if result else "β FAILED" | |
| print(f"{status}: {test_name}") | |
| if result: | |
| passed += 1 | |
| print(f"\nπ― Overall Results: {passed}/{total} tests passed") | |
| if passed == total: | |
| print("\nπ ALL TESTS PASSED! TRL compatibility is complete.") | |
| return True | |
| else: | |
| print(f"\nβ οΈ {total - passed} test(s) failed. Please review the implementation.") | |
| return False | |
| if __name__ == "__main__": | |
| success = main() | |
| sys.exit(0 if success else 1) |