|
import unittest |
|
import os |
|
import sys |
|
import json |
|
from unittest.mock import patch, MagicMock |
|
|
|
|
|
# Append the parent of this file's directory to the import search path so
# that the `from model_config import ModelConfigManager` import below can
# resolve a module located there.
sys.path.append(
    os.path.dirname(
        os.path.dirname(os.path.abspath(__file__))
    )
)
|
|
|
|
|
from model_config import ModelConfigManager |
|
|
|
class MockDB:
    """In-memory stand-in for the database handed to ModelConfigManager.

    Configs live in a plain dict keyed by config id; nothing is persisted.
    """

    def __init__(self):
        # Backing store: config_id -> config object.
        self.configs = dict()

    def get_config(self, config_id):
        """Return the config stored under ``config_id``, or None if there is none."""
        store = self.configs
        return store.get(config_id, None)

    def add_config(self, config_id, config):
        """Store ``config`` under ``config_id`` (replacing any previous entry).

        Returns:
            The ``config_id`` that was passed in.
        """
        self.configs.update({config_id: config})
        return config_id
|
|
|
class TestModelConfigManager(unittest.TestCase):
    """Test the ModelConfigManager class.

    Every test runs against its own freshly created temporary ``config_dir``,
    so files left by other tests, by an earlier aborted run, or by a
    concurrent run can never leak into the assertions (e.g. the exact
    count checked in ``test_get_available_configs``).
    """

    def setUp(self):
        """Create a MockDB, an isolated config directory, and the manager under test."""
        # Function-scope import keeps this change self-contained.
        import tempfile

        self.db = MockDB()

        # A unique directory per test replaces the fixed, CWD-relative
        # "test_model_configs" directory, which could contain stale files and
        # was shared between parallel runs.
        tmp_dir = tempfile.TemporaryDirectory(prefix="test_model_configs_")
        # Registered before the manager is constructed so the directory is
        # removed even if construction fails (tearDown would not run then).
        # cleanup() also removes nested subdirectories, which a
        # unlink-files-then-rmdir teardown cannot.
        self.addCleanup(tmp_dir.cleanup)
        self.config_dir = tmp_dir.name

        self.manager = ModelConfigManager(self.db)
        # Redirect the manager's config files to the isolated directory.
        self.manager.config_dir = self.config_dir

    def _write_config(self, config_id, config):
        """Write ``config`` as ``<config_id>.json`` in the test directory.

        Returns:
            The path of the written file.
        """
        config_path = os.path.join(self.config_dir, f"{config_id}.json")
        with open(config_path, "w") as f:
            json.dump(config, f)
        return config_path

    def test_initialize_default_configs(self):
        """_initialize_default_configs writes one JSON file per default config."""
        self.manager._initialize_default_configs()

        for model_type in self.manager.default_configs:
            with self.subTest(model_type=model_type):
                expected = self.manager.default_configs[model_type]
                config_path = os.path.join(self.config_dir, f"{model_type}.json")
                self.assertTrue(os.path.exists(config_path))

                with open(config_path, "r") as f:
                    config = json.load(f)

                self.assertEqual(config["name"], expected["name"])
                self.assertEqual(config["description"], expected["description"])
                self.assertEqual(config["parameters"], expected["parameters"])

    def test_get_available_configs(self):
        """get_available_configs returns every config file in config_dir, with its id."""
        test_configs = {
            "test1": {
                "name": "Test Config 1",
                "description": "Test description 1",
                "parameters": {"temperature": 0.7},
            },
            "test2": {
                "name": "Test Config 2",
                "description": "Test description 2",
                "parameters": {"temperature": 0.8},
            },
        }

        for config_id, config in test_configs.items():
            self._write_config(config_id, config)

        configs = self.manager.get_available_configs()

        # Reliable only because config_dir is a fresh, private directory.
        self.assertEqual(len(configs), 2)

        config_ids = [config["id"] for config in configs]
        self.assertIn("test1", config_ids)
        self.assertIn("test2", config_ids)

        for config in configs:
            expected = test_configs[config["id"]]
            self.assertEqual(config["name"], expected["name"])
            self.assertEqual(config["description"], expected["description"])
            self.assertEqual(config["parameters"], expected["parameters"])

    def test_get_config(self):
        """get_config returns a stored config with its id, or None if it is unknown."""
        config = {
            "name": "Test Config",
            "description": "Test description",
            "parameters": {"temperature": 0.7},
        }
        self._write_config("test", config)

        result = self.manager.get_config("test")

        self.assertIsNotNone(result)
        self.assertEqual(result["id"], "test")
        self.assertEqual(result["name"], config["name"])
        self.assertEqual(result["description"], config["description"])
        self.assertEqual(result["parameters"], config["parameters"])

        result = self.manager.get_config("nonexistent")
        self.assertIsNone(result)

    def test_add_config(self):
        """add_config derives an id from the name and writes the config file."""
        name = "Test Config"
        description = "Test description"
        parameters = {"temperature": 0.7, "top_k": 50}

        config_id = self.manager.add_config(name, description, parameters)

        self.assertIsNotNone(config_id)
        self.assertEqual(config_id, "test_config")

        config_path = os.path.join(self.config_dir, f"{config_id}.json")
        self.assertTrue(os.path.exists(config_path))

        with open(config_path, "r") as f:
            config = json.load(f)

        self.assertEqual(config["name"], name)
        self.assertEqual(config["description"], description)
        self.assertEqual(config["parameters"], parameters)

    def test_update_config(self):
        """update_config rewrites an existing config and returns False for unknown ids."""
        config = {
            "name": "Test Config",
            "description": "Test description",
            "parameters": {"temperature": 0.7},
        }
        config_path = self._write_config("test", config)

        new_name = "Updated Config"
        new_description = "Updated description"
        new_parameters = {"temperature": 0.8, "top_k": 60}

        success = self.manager.update_config("test", new_name, new_description, new_parameters)

        self.assertTrue(success)

        with open(config_path, "r") as f:
            updated_config = json.load(f)

        self.assertEqual(updated_config["name"], new_name)
        self.assertEqual(updated_config["description"], new_description)
        self.assertEqual(updated_config["parameters"], new_parameters)

        success = self.manager.update_config("nonexistent", "New Name", "New Description", {})
        self.assertFalse(success)

    def test_delete_config(self):
        """delete_config removes user configs but refuses unknown ids and defaults."""
        config = {
            "name": "Test Config",
            "description": "Test description",
            "parameters": {"temperature": 0.7},
        }
        config_path = self._write_config("test", config)

        success = self.manager.delete_config("test")

        self.assertTrue(success)
        self.assertFalse(os.path.exists(config_path))

        success = self.manager.delete_config("nonexistent")
        self.assertFalse(success)

        # Default configs must not be deletable: the call reports failure and
        # the file stays on disk.
        for model_type in self.manager.default_configs:
            with self.subTest(model_type=model_type):
                default_path = self._write_config(
                    model_type, self.manager.default_configs[model_type]
                )

                success = self.manager.delete_config(model_type)
                self.assertFalse(success)
                self.assertTrue(os.path.exists(default_path))

    def test_apply_config_to_model_params(self):
        """apply_config_to_model_params overlays config parameters onto model params."""
        config = {
            "name": "Test Config",
            "description": "Test description",
            "parameters": {
                "temperature": 0.7,
                "top_k": 50,
                "top_p": 0.9,
            },
        }
        self._write_config("test", config)

        original_params = {
            "temperature": 0.5,
            "max_length": 100,
        }

        # Each call gets its own copy. If the method mutated its argument, a
        # shared dict would make the "nonexistent" assertion below compare the
        # result with already-modified input and pass trivially.
        result = self.manager.apply_config_to_model_params(dict(original_params), "test")

        self.assertEqual(result["temperature"], 0.7)
        self.assertEqual(result["top_k"], 50)
        self.assertEqual(result["top_p"], 0.9)
        self.assertEqual(result["max_length"], 100)

        result = self.manager.apply_config_to_model_params(dict(original_params), "nonexistent")
        self.assertEqual(result, original_params)
|
|
|
if __name__ == "__main__":
    # Run this module's tests when the file is executed directly.
    unittest.main()
|
|