|
""" |
|
Tests for FinRL Agent |
|
|
|
This module contains comprehensive tests for the FinRL agent functionality. |
|
""" |
|
|
|
import pytest |
|
import pandas as pd |
|
import numpy as np |
|
import yaml |
|
import tempfile |
|
import os |
|
from unittest.mock import Mock, patch |
|
|
|
|
|
import sys |
|
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) |
|
|
|
from agentic_ai_system.finrl_agent import ( |
|
FinRLAgent, |
|
FinRLConfig, |
|
TradingEnvironment |
|
) |
|
|
|
|
|
class TestFinRLConfig: |
|
"""Test FinRL configuration""" |
|
|
|
def test_default_config(self): |
|
"""Test default configuration values""" |
|
config = FinRLConfig() |
|
|
|
assert config.algorithm == "PPO" |
|
assert config.learning_rate == 0.0003 |
|
assert config.batch_size == 64 |
|
assert config.gamma == 0.99 |
|
|
|
def test_custom_config(self): |
|
"""Test custom configuration values""" |
|
config = FinRLConfig( |
|
algorithm="A2C", |
|
learning_rate=0.001, |
|
batch_size=128 |
|
) |
|
|
|
assert config.algorithm == "A2C" |
|
assert config.learning_rate == 0.001 |
|
assert config.batch_size == 128 |
|
|
|
|
|
class TestTradingEnvironment: |
|
"""Test trading environment""" |
|
|
|
@pytest.fixture |
|
def sample_data(self): |
|
"""Create sample market data""" |
|
dates = pd.date_range('2024-01-01', periods=100, freq='1h') |
|
data = pd.DataFrame({ |
|
'open': np.random.uniform(100, 200, 100), |
|
'high': np.random.uniform(100, 200, 100), |
|
'low': np.random.uniform(100, 200, 100), |
|
'close': np.random.uniform(100, 200, 100), |
|
'volume': np.random.uniform(1000, 10000, 100), |
|
'sma_20': np.random.uniform(100, 200, 100), |
|
'sma_50': np.random.uniform(100, 200, 100), |
|
'rsi': np.random.uniform(0, 100, 100), |
|
'bb_upper': np.random.uniform(100, 200, 100), |
|
'bb_lower': np.random.uniform(100, 200, 100), |
|
'macd': np.random.uniform(-10, 10, 100) |
|
}, index=dates) |
|
return data |
|
|
|
def test_environment_initialization(self, sample_data): |
|
"""Test environment initialization""" |
|
config = {'trading': {'symbol': 'AAPL'}} |
|
env = TradingEnvironment(sample_data, config) |
|
|
|
assert env.initial_balance == 100000 |
|
assert env.transaction_fee == 0.001 |
|
assert env.max_position == 100 |
|
assert env.action_space.n == 3 |
|
assert len(env.observation_space.shape) == 1 |
|
|
|
def test_environment_reset(self, sample_data): |
|
"""Test environment reset""" |
|
config = {'trading': {'symbol': 'AAPL'}} |
|
env = TradingEnvironment(sample_data, config) |
|
obs, info = env.reset() |
|
|
|
assert env.current_step == 0 |
|
assert env.balance == env.initial_balance |
|
assert env.position == 0 |
|
assert env.portfolio_value == env.initial_balance |
|
assert isinstance(obs, np.ndarray) |
|
assert isinstance(info, dict) |
|
|
|
def test_environment_step(self, sample_data): |
|
"""Test environment step""" |
|
config = {'trading': {'symbol': 'AAPL'}} |
|
env = TradingEnvironment(sample_data, config) |
|
obs, info = env.reset() |
|
|
|
|
|
obs, reward, done, truncated, info = env.step(1) |
|
|
|
assert isinstance(obs, np.ndarray) |
|
assert isinstance(reward, float) |
|
assert isinstance(done, bool) |
|
assert isinstance(truncated, bool) |
|
assert isinstance(info, dict) |
|
assert env.current_step == 1 |
|
|
|
def test_buy_action(self, sample_data): |
|
"""Test buy action""" |
|
config = {'trading': {'symbol': 'AAPL'}} |
|
env = TradingEnvironment(sample_data, config, initial_balance=10000) |
|
obs, info = env.reset() |
|
|
|
initial_balance = env.balance |
|
initial_position = env.position |
|
|
|
|
|
obs, reward, done, truncated, info = env.step(2) |
|
|
|
assert env.position > initial_position |
|
assert env.balance < initial_balance |
|
|
|
def test_sell_action(self, sample_data): |
|
"""Test sell action""" |
|
config = {'trading': {'symbol': 'AAPL'}} |
|
env = TradingEnvironment(sample_data, config, initial_balance=10000) |
|
obs, info = env.reset() |
|
|
|
|
|
obs, reward, done, truncated, info = env.step(2) |
|
initial_position = env.position |
|
initial_balance = env.balance |
|
|
|
|
|
obs, reward, done, truncated, info = env.step(0) |
|
|
|
assert env.position < initial_position |
|
assert env.balance > initial_balance |
|
|
|
def test_portfolio_value_calculation(self, sample_data): |
|
"""Test portfolio value calculation""" |
|
config = {'trading': {'symbol': 'AAPL'}} |
|
env = TradingEnvironment(sample_data, config) |
|
obs, info = env.reset() |
|
|
|
|
|
obs, reward, done, truncated, info = env.step(2) |
|
|
|
|
|
current_price = sample_data.iloc[env.current_step]['close'] |
|
expected_value = env.balance + (env.position * current_price) |
|
|
|
assert abs(env.portfolio_value - expected_value) < 5000.0 |
|
|
|
|
|
class TestFinRLAgent: |
|
"""Test FinRL agent""" |
|
|
|
@pytest.fixture |
|
def sample_data(self): |
|
"""Create sample market data""" |
|
dates = pd.date_range('2024-01-01', periods=100, freq='1h') |
|
data = pd.DataFrame({ |
|
'open': np.random.uniform(100, 200, 100), |
|
'high': np.random.uniform(100, 200, 100), |
|
'low': np.random.uniform(100, 200, 100), |
|
'close': np.random.uniform(100, 200, 100), |
|
'volume': np.random.uniform(1000, 10000, 100) |
|
}, index=dates) |
|
return data |
|
|
|
@pytest.fixture |
|
def finrl_config(self): |
|
"""Create FinRL configuration""" |
|
return FinRLConfig( |
|
algorithm="PPO", |
|
learning_rate=0.0003, |
|
batch_size=32 |
|
) |
|
|
|
def test_agent_initialization(self, finrl_config): |
|
"""Test agent initialization""" |
|
agent = FinRLAgent(finrl_config) |
|
|
|
assert agent.config == finrl_config |
|
assert agent.model is None |
|
assert agent.env is None |
|
|
|
def test_prepare_data(self, finrl_config, sample_data): |
|
"""Test data preparation""" |
|
agent = FinRLAgent(finrl_config) |
|
prepared_data = agent.prepare_data(sample_data) |
|
|
|
|
|
assert 'sma_20' in prepared_data.columns |
|
assert 'sma_50' in prepared_data.columns |
|
assert 'rsi' in prepared_data.columns |
|
assert 'bb_upper' in prepared_data.columns |
|
assert 'bb_lower' in prepared_data.columns |
|
assert 'macd' in prepared_data.columns |
|
|
|
|
|
assert not prepared_data.isnull().any().any() |
|
|
|
def test_create_environment(self, finrl_config, sample_data): |
|
"""Test environment creation""" |
|
agent = FinRLAgent(finrl_config) |
|
config = {'trading': {'symbol': 'AAPL'}} |
|
env = agent.create_environment(sample_data, config) |
|
|
|
assert isinstance(env, TradingEnvironment) |
|
assert len(env.data) == len(sample_data) |
|
|
|
def test_technical_indicators_calculation(self, finrl_config): |
|
"""Test technical indicators calculation""" |
|
agent = FinRLAgent(finrl_config) |
|
|
|
|
|
prices = pd.Series([100, 101, 99, 102, 98, 103, 97, 104, 96, 105]) |
|
rsi = agent._calculate_rsi(prices, period=3) |
|
assert len(rsi) == len(prices) |
|
assert not rsi.isnull().all() |
|
|
|
|
|
bb_upper, bb_lower = agent._calculate_bollinger_bands(prices, period=3) |
|
assert len(bb_upper) == len(prices) |
|
assert len(bb_lower) == len(prices) |
|
|
|
valid_mask = ~(bb_upper.isna() | bb_lower.isna()) |
|
if valid_mask.any(): |
|
assert (bb_upper[valid_mask] >= bb_lower[valid_mask]).all() |
|
|
|
|
|
macd = agent._calculate_macd(prices) |
|
assert len(macd) == len(prices) |
|
|
|
@pytest.mark.slow |
|
@patch('agentic_ai_system.finrl_agent.PPO') |
|
def test_training_ppo(self, mock_ppo, finrl_config, sample_data): |
|
"""Test PPO training""" |
|
|
|
mock_model = Mock() |
|
mock_ppo.return_value = mock_model |
|
|
|
agent = FinRLAgent(finrl_config) |
|
config = {'trading': {'symbol': 'AAPL'}} |
|
result = agent.train(sample_data, config, total_timesteps=5) |
|
|
|
assert result['algorithm'] == 'PPO' |
|
assert result['total_timesteps'] == 5 |
|
assert result['success'] == True |
|
mock_model.learn.assert_called_once() |
|
|
|
@pytest.mark.slow |
|
@patch('agentic_ai_system.finrl_agent.A2C') |
|
def test_training_a2c(self, mock_a2c): |
|
"""Test A2C training""" |
|
config = FinRLConfig(algorithm="A2C") |
|
mock_model = Mock() |
|
mock_a2c.return_value = mock_model |
|
|
|
agent = FinRLAgent(config) |
|
sample_data = pd.DataFrame({ |
|
'open': [100, 101, 102], |
|
'high': [101, 102, 103], |
|
'low': [99, 100, 101], |
|
'close': [100, 101, 102], |
|
'volume': [1000, 1100, 1200] |
|
}) |
|
|
|
trading_config = {'trading': {'symbol': 'AAPL'}} |
|
result = agent.train(sample_data, trading_config, total_timesteps=5) |
|
|
|
assert result['algorithm'] == 'A2C' |
|
assert result['success'] == True |
|
mock_model.learn.assert_called_once() |
|
|
|
def test_invalid_algorithm(self): |
|
"""Test invalid algorithm handling""" |
|
config = FinRLConfig(algorithm="INVALID") |
|
agent = FinRLAgent(config) |
|
sample_data = pd.DataFrame({ |
|
'open': [100, 101, 102], |
|
'high': [101, 102, 103], |
|
'low': [99, 100, 101], |
|
'close': [100, 101, 102], |
|
'volume': [1000, 1100, 1200] |
|
}) |
|
|
|
trading_config = {'trading': {'symbol': 'AAPL'}} |
|
result = agent.train(sample_data, trading_config, total_timesteps=100) |
|
|
|
|
|
assert result['success'] == False |
|
assert 'error' in result |
|
|
|
def test_predict_without_training(self, finrl_config, sample_data): |
|
"""Test prediction without training""" |
|
agent = FinRLAgent(finrl_config) |
|
|
|
config = {'trading': {'symbol': 'AAPL'}} |
|
result = agent.predict(sample_data, config) |
|
|
|
|
|
assert result['success'] == False |
|
assert 'error' in result |
|
|
|
def test_evaluate_without_training(self, finrl_config, sample_data): |
|
"""Test evaluation without training""" |
|
agent = FinRLAgent(finrl_config) |
|
|
|
config = {'trading': {'symbol': 'AAPL'}} |
|
result = agent.evaluate(sample_data, config) |
|
|
|
|
|
assert result['success'] == False |
|
assert 'error' in result |
|
|
|
@patch('agentic_ai_system.finrl_agent.PPO') |
|
def test_save_and_load_model(self, mock_ppo, finrl_config, sample_data): |
|
"""Test model saving and loading""" |
|
|
|
mock_model = Mock() |
|
mock_ppo.return_value = mock_model |
|
mock_ppo.load.return_value = mock_model |
|
|
|
agent = FinRLAgent(finrl_config) |
|
|
|
|
|
config = {'trading': {'symbol': 'AAPL'}} |
|
agent.train(sample_data, config, total_timesteps=100) |
|
|
|
|
|
with tempfile.NamedTemporaryFile(suffix='.zip', delete=False) as tmp_file: |
|
result = agent.save_model(tmp_file.name) |
|
assert result == True |
|
|
|
mock_model.save.assert_any_call(tmp_file.name) |
|
|
|
|
|
result = agent.load_model(tmp_file.name, config) |
|
assert result == True |
|
mock_ppo.load.assert_called_once_with(tmp_file.name) |
|
|
|
|
|
os.unlink(tmp_file.name) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
pytest.main([__file__]) |