# Source: surf-spot-finder/tests/unit/agents/test_unit_smolagents.py
# Last change: commit f53eca9, "make tools configurable (#21)",
# by David de la Iglesia Castro.
import os
import pytest
from unittest.mock import patch, MagicMock
import contextlib
from surf_spot_finder.agents.smolagents import run_smolagent
@pytest.fixture
def common_patches():
    """Replace ``CodeAgent`` and ``LiteLLMModel`` in the smolagents module with mocks.

    The patches are already active when the test body starts.

    Yields:
        A ``(patch_context, litellm_model_mock, code_agent_mock)`` tuple, where
        ``patch_context`` is the ``contextlib.ExitStack`` holding both patches.
        A test may write ``with patch_context:`` to undo the patches when its
        block ends; any patches still active are undone at fixture teardown.
    """
    litellm_model_mock = MagicMock()
    code_agent_mock = MagicMock()
    # Using the stack as a context manager guarantees that, if applying the
    # second patch fails, the first one is rolled back instead of leaking into
    # later tests. Exiting an already-closed ExitStack is a no-op, so tests
    # that close ``patch_context`` themselves are unaffected.
    with contextlib.ExitStack() as patch_context:
        patch_context.enter_context(
            patch("surf_spot_finder.agents.smolagents.CodeAgent", code_agent_mock)
        )
        patch_context.enter_context(
            patch("surf_spot_finder.agents.smolagents.LiteLLMModel", litellm_model_mock)
        )
        yield patch_context, litellm_model_mock, code_agent_mock


def test_run_smolagent_with_api_key_var(common_patches):
    """The variable named by ``api_key_var`` is read from the environment and passed on."""
    stack, model_cls, agent_cls = common_patches
    env_overrides = {"TEST_API_KEY": "test-key-12345"}

    with stack, patch.dict(os.environ, env_overrides):
        run_smolagent("openai/gpt-4", "Test prompt", api_key_var="TEST_API_KEY")

    # The model is built with the requested id, the resolved key, and no base URL.
    model_cls.assert_called()
    kwargs = model_cls.call_args.kwargs
    assert kwargs["model_id"] == "openai/gpt-4"
    assert kwargs["api_key"] == "test-key-12345"
    assert kwargs["api_base"] is None

    # Exactly one agent is constructed, and it is run once on the prompt verbatim.
    agent_cls.assert_called_once()
    agent_cls.return_value.run.assert_called_once_with("Test prompt")


def test_run_smolagent_with_custom_api_base(common_patches):
    """A custom ``api_base`` reaches the model together with the model id and API key."""
    stack, model_cls, *_ = common_patches
    expected = {
        "model_id": "anthropic/claude-3-sonnet",
        "api_key": "test-key-12345",
        "api_base": "https://custom-api.example.com",
    }

    with stack, patch.dict(os.environ, {"TEST_API_KEY": "test-key-12345"}):
        run_smolagent(
            expected["model_id"],
            "Test prompt",
            api_key_var="TEST_API_KEY",
            api_base=expected["api_base"],
        )

    # Compare only the keyword arguments this test cares about, taken from the
    # most recent model construction.
    latest_kwargs = model_cls.call_args_list[-1].kwargs
    assert {name: latest_kwargs[name] for name in expected} == expected


def test_run_smolagent_without_api_key(common_patches):
    """Without ``api_key_var`` the model is created with ``api_key=None``."""
    stack, model_cls, *_ = common_patches

    with stack:
        run_smolagent("ollama_chat/deepseek-r1", "Test prompt")

    latest = model_cls.call_args_list[-1].kwargs
    assert latest["model_id"] == "ollama_chat/deepseek-r1"
    assert latest["api_key"] is None


def test_run_smolagent_environment_error(common_patches):
    """A missing API-key environment variable surfaces as ``KeyError``.

    Like the other tests, this uses ``common_patches`` so that no real
    ``LiteLLMModel``/``CodeAgent`` can be created or run, even if the order of
    operations inside ``run_smolagent`` changes.
    """
    patch_context, _, code_agent_mock = common_patches
    # clear=True empties os.environ for the duration of the block, so
    # MISSING_KEY is guaranteed absent regardless of the developer's shell.
    with patch_context, patch.dict(os.environ, {}, clear=True):
        with pytest.raises(KeyError, match="MISSING_KEY"):
            run_smolagent("test-model", "Test prompt", api_key_var="MISSING_KEY")
    # The configuration error must abort before any agent is run.
    code_agent_mock.return_value.run.assert_not_called()