Edwin Salguero
commited on
Commit
·
2c67d05
1
Parent(s):
859af74
Add FinRL integration with comprehensive RL trading agent
Browse files- Add FinRL agent with support for PPO, A2C, DDPG, and TD3 algorithms
- Create custom trading environment compatible with Gymnasium
- Implement technical indicators integration (RSI, Bollinger Bands, MACD)
- Add comprehensive configuration system for FinRL parameters
- Create demo script with training, evaluation, and visualization
- Add comprehensive test suite for FinRL functionality
- Update requirements.txt with FinRL dependencies
- Update README with detailed FinRL documentation
- Create necessary directories for models, logs, and plots
- README.md +133 -1
- agentic_ai_system/finrl_agent.py +447 -0
- config.yaml +27 -0
- finrl_demo.py +294 -0
- requirements.txt +5 -0
- tests/test_finrl_agent.py +373 -0
README.md
CHANGED
@@ -1,6 +1,6 @@
|
|
1 |
# Algorithmic Trading System
|
2 |
|
3 |
-
A comprehensive algorithmic trading system with synthetic data generation, comprehensive logging,
|
4 |
|
5 |
## Features
|
6 |
|
@@ -10,6 +10,15 @@ A comprehensive algorithmic trading system with synthetic data generation, compr
|
|
10 |
- **Risk Management**: Position sizing and drawdown limits
|
11 |
- **Order Execution**: Simulated broker integration with realistic execution delays
|
12 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
### Synthetic Data Generation
|
14 |
- **Realistic Market Data**: Generate OHLCV data using geometric Brownian motion
|
15 |
- **Multiple Frequencies**: Support for 1min, 5min, 1H, and 1D data
|
@@ -226,6 +235,129 @@ logger.warning("High volatility detected")
|
|
226 |
logger.error("Order execution failed", exc_info=True)
|
227 |
```
|
228 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
229 |
## Testing
|
230 |
|
231 |
### Test Structure
|
|
|
1 |
# Algorithmic Trading System
|
2 |
|
3 |
+
A comprehensive algorithmic trading system with synthetic data generation, comprehensive logging, extensive testing capabilities, and FinRL reinforcement learning integration.
|
4 |
|
5 |
## Features
|
6 |
|
|
|
10 |
- **Risk Management**: Position sizing and drawdown limits
|
11 |
- **Order Execution**: Simulated broker integration with realistic execution delays
|
12 |
|
13 |
+
### FinRL Reinforcement Learning
|
14 |
+
- **Multiple RL Algorithms**: Support for PPO, A2C, DDPG, and TD3
|
15 |
+
- **Custom Trading Environment**: Gymnasium-compatible environment for RL training
|
16 |
+
- **Technical Indicators Integration**: Automatic calculation and inclusion of technical indicators
|
17 |
+
- **Portfolio Management**: Realistic portfolio simulation with transaction costs
|
18 |
+
- **Model Persistence**: Save and load trained models for inference
|
19 |
+
- **TensorBoard Integration**: Training progress visualization and monitoring
|
20 |
+
- **Comprehensive Evaluation**: Performance metrics including Sharpe ratio and total returns
|
21 |
+
|
22 |
### Synthetic Data Generation
|
23 |
- **Realistic Market Data**: Generate OHLCV data using geometric Brownian motion
|
24 |
- **Multiple Frequencies**: Support for 1min, 5min, 1H, and 1D data
|
|
|
235 |
logger.error("Order execution failed", exc_info=True)
|
236 |
```
|
237 |
|
238 |
+
## FinRL Integration
|
239 |
+
|
240 |
+
### Overview
|
241 |
+
The system now includes FinRL (Financial Reinforcement Learning) integration, providing state-of-the-art reinforcement learning capabilities for algorithmic trading. The FinRL agent can learn optimal trading strategies through interaction with a simulated market environment.
|
242 |
+
|
243 |
+
### Supported Algorithms
|
244 |
+
- **PPO (Proximal Policy Optimization)**: Stable policy gradient method
|
245 |
+
- **A2C (Advantage Actor-Critic)**: Actor-critic method with advantage estimation
|
246 |
+
- **DDPG (Deep Deterministic Policy Gradient)**: Continuous action space algorithm
|
247 |
+
- **TD3 (Twin Delayed DDPG)**: Improved version of DDPG with twin critics
|
248 |
+
|
249 |
+
### Trading Environment
|
250 |
+
The custom trading environment provides:
|
251 |
+
- **Action Space**: Discrete actions (0=Buy, 1=Hold, 2=Sell)
|
252 |
+
- **Observation Space**: OHLCV data + technical indicators + portfolio state
|
253 |
+
- **Reward Function**: Portfolio return-based rewards
|
254 |
+
- **Transaction Costs**: Realistic trading fees and slippage
|
255 |
+
- **Position Limits**: Maximum position constraints
|
256 |
+
|
257 |
+
### Usage Examples
|
258 |
+
|
259 |
+
#### Basic FinRL Training
|
260 |
+
```python
|
261 |
+
from agentic_ai_system.finrl_agent import FinRLAgent, FinRLConfig
|
262 |
+
import pandas as pd
|
263 |
+
|
264 |
+
# Create configuration
|
265 |
+
config = FinRLConfig(
|
266 |
+
algorithm="PPO",
|
267 |
+
learning_rate=0.0003,
|
268 |
+
batch_size=64,
|
269 |
+
total_timesteps=100000
|
270 |
+
)
|
271 |
+
|
272 |
+
# Initialize agent
|
273 |
+
agent = FinRLAgent(config)
|
274 |
+
|
275 |
+
# Train the agent
|
276 |
+
training_result = agent.train(
|
277 |
+
data=market_data,
|
278 |
+
total_timesteps=100000,
|
279 |
+
eval_freq=10000
|
280 |
+
)
|
281 |
+
|
282 |
+
# Generate predictions
|
283 |
+
predictions = agent.predict(test_data)
|
284 |
+
|
285 |
+
# Evaluate performance
|
286 |
+
evaluation = agent.evaluate(test_data)
|
287 |
+
print(f"Total Return: {evaluation['total_return']:.2%}")
|
288 |
+
```
|
289 |
+
|
290 |
+
#### Using Configuration File
|
291 |
+
```python
|
292 |
+
from agentic_ai_system.finrl_agent import create_finrl_agent_from_config
|
293 |
+
|
294 |
+
# Create agent from config file
|
295 |
+
agent = create_finrl_agent_from_config('config.yaml')
|
296 |
+
|
297 |
+
# Train and evaluate
|
298 |
+
agent.train(market_data)
|
299 |
+
results = agent.evaluate(test_data)
|
300 |
+
```
|
301 |
+
|
302 |
+
#### Running FinRL Demo
|
303 |
+
```bash
|
304 |
+
# Run the complete FinRL demo
|
305 |
+
python finrl_demo.py
|
306 |
+
|
307 |
+
# This will:
|
308 |
+
# 1. Generate synthetic training and test data
|
309 |
+
# 2. Train a FinRL agent
|
310 |
+
# 3. Evaluate performance
|
311 |
+
# 4. Generate trading predictions
|
312 |
+
# 5. Create visualization plots
|
313 |
+
```
|
314 |
+
|
315 |
+
### Configuration
|
316 |
+
FinRL settings can be configured in `config.yaml`:
|
317 |
+
|
318 |
+
```yaml
|
319 |
+
finrl:
|
320 |
+
algorithm: 'PPO' # PPO, A2C, DDPG, TD3
|
321 |
+
learning_rate: 0.0003
|
322 |
+
batch_size: 64
|
323 |
+
buffer_size: 1000000
|
324 |
+
gamma: 0.99
|
325 |
+
tensorboard_log: 'logs/finrl_tensorboard'
|
326 |
+
training:
|
327 |
+
total_timesteps: 100000
|
328 |
+
eval_freq: 10000
|
329 |
+
save_best_model: true
|
330 |
+
model_save_path: 'models/finrl_best/'
|
331 |
+
inference:
|
332 |
+
use_trained_model: false
|
333 |
+
model_path: 'models/finrl_best/best_model'
|
334 |
+
```
|
335 |
+
|
336 |
+
### Model Management
|
337 |
+
```python
|
338 |
+
# Save trained model
|
339 |
+
agent.save_model('models/my_finrl_model')
|
340 |
+
|
341 |
+
# Load pre-trained model
|
342 |
+
agent.load_model('models/my_finrl_model')
|
343 |
+
|
344 |
+
# Continue training
|
345 |
+
agent.train(more_data, total_timesteps=50000)
|
346 |
+
```
|
347 |
+
|
348 |
+
### Performance Monitoring
|
349 |
+
- **TensorBoard Integration**: Monitor training progress
|
350 |
+
- **Evaluation Metrics**: Total return, Sharpe ratio, portfolio value
|
351 |
+
- **Trading Statistics**: Buy/sell signal analysis
|
352 |
+
- **Visualization**: Price charts with trading signals
|
353 |
+
|
354 |
+
### Advanced Features
|
355 |
+
- **Multi-timeframe Support**: Train on different data frequencies
|
356 |
+
- **Feature Engineering**: Automatic technical indicator calculation
|
357 |
+
- **Risk Management**: Built-in position and drawdown limits
|
358 |
+
- **Backtesting**: Comprehensive backtesting capabilities
|
359 |
+
- **Hyperparameter Tuning**: Easy configuration for different algorithms
|
360 |
+
|
361 |
## Testing
|
362 |
|
363 |
### Test Structure
|
agentic_ai_system/finrl_agent.py
ADDED
@@ -0,0 +1,447 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
FinRL Agent for Algorithmic Trading
|
3 |
+
|
4 |
+
This module provides a FinRL-based reinforcement learning agent that can be integrated
|
5 |
+
with the existing algorithmic trading system. It supports various RL algorithms
|
6 |
+
including PPO, A2C, DDPG, and TD3.
|
7 |
+
"""
|
8 |
+
|
9 |
+
import numpy as np
|
10 |
+
import pandas as pd
|
11 |
+
import gymnasium as gym
|
12 |
+
from gymnasium import spaces
|
13 |
+
from stable_baselines3 import PPO, A2C, DDPG, TD3
|
14 |
+
from stable_baselines3.common.vec_env import DummyVecEnv
|
15 |
+
from stable_baselines3.common.callbacks import EvalCallback
|
16 |
+
import torch
|
17 |
+
import logging
|
18 |
+
from typing import Dict, List, Tuple, Optional, Any
|
19 |
+
from dataclasses import dataclass
|
20 |
+
import yaml
|
21 |
+
|
22 |
+
logger = logging.getLogger(__name__)
|
23 |
+
|
24 |
+
|
25 |
+
@dataclass
|
26 |
+
class FinRLConfig:
|
27 |
+
"""Configuration for FinRL agent"""
|
28 |
+
algorithm: str = "PPO" # PPO, A2C, DDPG, TD3
|
29 |
+
learning_rate: float = 0.0003
|
30 |
+
batch_size: int = 64
|
31 |
+
buffer_size: int = 1000000
|
32 |
+
learning_starts: int = 100
|
33 |
+
gamma: float = 0.99
|
34 |
+
tau: float = 0.005
|
35 |
+
train_freq: int = 1
|
36 |
+
gradient_steps: int = 1
|
37 |
+
target_update_interval: int = 1
|
38 |
+
exploration_fraction: float = 0.1
|
39 |
+
exploration_initial_eps: float = 1.0
|
40 |
+
exploration_final_eps: float = 0.05
|
41 |
+
max_grad_norm: float = 10.0
|
42 |
+
verbose: int = 1
|
43 |
+
tensorboard_log: str = "logs/finrl_tensorboard"
|
44 |
+
|
45 |
+
|
46 |
+
class TradingEnvironment(gym.Env):
|
47 |
+
"""
|
48 |
+
Custom trading environment for FinRL
|
49 |
+
|
50 |
+
This environment simulates a trading scenario where the agent can:
|
51 |
+
- Buy, sell, or hold positions
|
52 |
+
- Use technical indicators for decision making
|
53 |
+
- Manage portfolio value and risk
|
54 |
+
"""
|
55 |
+
|
56 |
+
def __init__(self, data: pd.DataFrame, initial_balance: float = 100000,
|
57 |
+
transaction_fee: float = 0.001, max_position: int = 100):
|
58 |
+
super().__init__()
|
59 |
+
|
60 |
+
self.data = data
|
61 |
+
self.initial_balance = initial_balance
|
62 |
+
self.transaction_fee = transaction_fee
|
63 |
+
self.max_position = max_position
|
64 |
+
|
65 |
+
# Reset state
|
66 |
+
self.reset()
|
67 |
+
|
68 |
+
# Define action space: [-1, 0, 1] for sell, hold, buy
|
69 |
+
self.action_space = spaces.Discrete(3)
|
70 |
+
|
71 |
+
# Define observation space
|
72 |
+
# Features: OHLCV + technical indicators + portfolio state
|
73 |
+
n_features = len(self._get_features(self.data.iloc[0]))
|
74 |
+
self.observation_space = spaces.Box(
|
75 |
+
low=-np.inf, high=np.inf, shape=(n_features,), dtype=np.float32
|
76 |
+
)
|
77 |
+
|
78 |
+
def _get_features(self, row: pd.Series) -> np.ndarray:
|
79 |
+
"""Extract features from market data row"""
|
80 |
+
features = []
|
81 |
+
|
82 |
+
# Price features
|
83 |
+
features.extend([
|
84 |
+
row['open'], row['high'], row['low'], row['close'], row['volume']
|
85 |
+
])
|
86 |
+
|
87 |
+
# Technical indicators (if available)
|
88 |
+
for indicator in ['sma_20', 'sma_50', 'rsi', 'bb_upper', 'bb_lower', 'macd']:
|
89 |
+
if indicator in row.index:
|
90 |
+
features.append(row[indicator])
|
91 |
+
else:
|
92 |
+
features.append(0.0)
|
93 |
+
|
94 |
+
# Portfolio state
|
95 |
+
features.extend([
|
96 |
+
self.balance,
|
97 |
+
self.position,
|
98 |
+
self.portfolio_value,
|
99 |
+
self.total_return
|
100 |
+
])
|
101 |
+
|
102 |
+
return np.array(features, dtype=np.float32)
|
103 |
+
|
104 |
+
def _calculate_portfolio_value(self) -> float:
|
105 |
+
"""Calculate current portfolio value"""
|
106 |
+
current_price = self.data.iloc[self.current_step]['close']
|
107 |
+
return self.balance + (self.position * current_price)
|
108 |
+
|
109 |
+
def _calculate_reward(self) -> float:
|
110 |
+
"""Calculate reward based on portfolio performance"""
|
111 |
+
current_value = self._calculate_portfolio_value()
|
112 |
+
previous_value = self.previous_portfolio_value
|
113 |
+
|
114 |
+
# Calculate return
|
115 |
+
if previous_value > 0:
|
116 |
+
return (current_value - previous_value) / previous_value
|
117 |
+
else:
|
118 |
+
return 0.0
|
119 |
+
|
120 |
+
def step(self, action: int) -> Tuple[np.ndarray, float, bool, bool, Dict]:
|
121 |
+
"""Execute one step in the environment"""
|
122 |
+
|
123 |
+
# Get current market data
|
124 |
+
current_data = self.data.iloc[self.current_step]
|
125 |
+
current_price = current_data['close']
|
126 |
+
|
127 |
+
# Execute action
|
128 |
+
if action == 0: # Sell
|
129 |
+
if self.position > 0:
|
130 |
+
shares_to_sell = min(self.position, self.max_position)
|
131 |
+
sell_value = shares_to_sell * current_price * (1 - self.transaction_fee)
|
132 |
+
self.balance += sell_value
|
133 |
+
self.position -= shares_to_sell
|
134 |
+
elif action == 2: # Buy
|
135 |
+
if self.balance > 0:
|
136 |
+
max_shares = min(
|
137 |
+
int(self.balance / current_price),
|
138 |
+
self.max_position - self.position
|
139 |
+
)
|
140 |
+
if max_shares > 0:
|
141 |
+
buy_value = max_shares * current_price * (1 + self.transaction_fee)
|
142 |
+
self.balance -= buy_value
|
143 |
+
self.position += max_shares
|
144 |
+
|
145 |
+
# Update portfolio value
|
146 |
+
self.previous_portfolio_value = self.portfolio_value
|
147 |
+
self.portfolio_value = self._calculate_portfolio_value()
|
148 |
+
self.total_return = (self.portfolio_value - self.initial_balance) / self.initial_balance
|
149 |
+
|
150 |
+
# Calculate reward
|
151 |
+
reward = self._calculate_reward()
|
152 |
+
|
153 |
+
# Move to next step
|
154 |
+
self.current_step += 1
|
155 |
+
|
156 |
+
# Check if episode is done
|
157 |
+
done = self.current_step >= len(self.data) - 1
|
158 |
+
|
159 |
+
# Get observation
|
160 |
+
if not done:
|
161 |
+
observation = self._get_features(self.data.iloc[self.current_step])
|
162 |
+
else:
|
163 |
+
# Use last available data for final observation
|
164 |
+
observation = self._get_features(self.data.iloc[-1])
|
165 |
+
|
166 |
+
info = {
|
167 |
+
'balance': self.balance,
|
168 |
+
'position': self.position,
|
169 |
+
'portfolio_value': self.portfolio_value,
|
170 |
+
'total_return': self.total_return,
|
171 |
+
'current_price': current_price
|
172 |
+
}
|
173 |
+
|
174 |
+
return observation, reward, done, False, info
|
175 |
+
|
176 |
+
def reset(self, seed: Optional[int] = None) -> Tuple[np.ndarray, Dict]:
|
177 |
+
"""Reset the environment"""
|
178 |
+
super().reset(seed=seed)
|
179 |
+
|
180 |
+
self.current_step = 0
|
181 |
+
self.balance = self.initial_balance
|
182 |
+
self.position = 0
|
183 |
+
self.portfolio_value = self.initial_balance
|
184 |
+
self.previous_portfolio_value = self.initial_balance
|
185 |
+
self.total_return = 0.0
|
186 |
+
|
187 |
+
observation = self._get_features(self.data.iloc[self.current_step])
|
188 |
+
info = {
|
189 |
+
'balance': self.balance,
|
190 |
+
'position': self.position,
|
191 |
+
'portfolio_value': self.portfolio_value,
|
192 |
+
'total_return': self.total_return
|
193 |
+
}
|
194 |
+
|
195 |
+
return observation, info
|
196 |
+
|
197 |
+
|
198 |
+
class FinRLAgent:
|
199 |
+
"""
|
200 |
+
FinRL-based reinforcement learning agent for algorithmic trading
|
201 |
+
"""
|
202 |
+
|
203 |
+
def __init__(self, config: FinRLConfig):
|
204 |
+
self.config = config
|
205 |
+
self.model = None
|
206 |
+
self.env = None
|
207 |
+
self.eval_env = None
|
208 |
+
self.callback = None
|
209 |
+
|
210 |
+
logger.info(f"Initializing FinRL agent with algorithm: {config.algorithm}")
|
211 |
+
|
212 |
+
def create_environment(self, data: pd.DataFrame, initial_balance: float = 100000) -> TradingEnvironment:
|
213 |
+
"""Create trading environment from market data"""
|
214 |
+
return TradingEnvironment(
|
215 |
+
data=data,
|
216 |
+
initial_balance=initial_balance,
|
217 |
+
transaction_fee=0.001,
|
218 |
+
max_position=100
|
219 |
+
)
|
220 |
+
|
221 |
+
def prepare_data(self, data: pd.DataFrame) -> pd.DataFrame:
|
222 |
+
"""Prepare data with technical indicators for FinRL"""
|
223 |
+
df = data.copy()
|
224 |
+
|
225 |
+
# Add technical indicators if not present
|
226 |
+
if 'sma_20' not in df.columns:
|
227 |
+
df['sma_20'] = df['close'].rolling(window=20).mean()
|
228 |
+
if 'sma_50' not in df.columns:
|
229 |
+
df['sma_50'] = df['close'].rolling(window=50).mean()
|
230 |
+
if 'rsi' not in df.columns:
|
231 |
+
df['rsi'] = self._calculate_rsi(df['close'])
|
232 |
+
if 'bb_upper' not in df.columns or 'bb_lower' not in df.columns:
|
233 |
+
bb_upper, bb_lower = self._calculate_bollinger_bands(df['close'])
|
234 |
+
df['bb_upper'] = bb_upper
|
235 |
+
df['bb_lower'] = bb_lower
|
236 |
+
if 'macd' not in df.columns:
|
237 |
+
df['macd'] = self._calculate_macd(df['close'])
|
238 |
+
|
239 |
+
# Fill NaN values
|
240 |
+
df = df.fillna(method='bfill').fillna(0)
|
241 |
+
|
242 |
+
return df
|
243 |
+
|
244 |
+
def _calculate_rsi(self, prices: pd.Series, period: int = 14) -> pd.Series:
|
245 |
+
"""Calculate RSI indicator"""
|
246 |
+
delta = prices.diff()
|
247 |
+
gain = (delta.where(delta > 0, 0)).rolling(window=period).mean()
|
248 |
+
loss = (-delta.where(delta < 0, 0)).rolling(window=period).mean()
|
249 |
+
rs = gain / loss
|
250 |
+
rsi = 100 - (100 / (1 + rs))
|
251 |
+
return rsi
|
252 |
+
|
253 |
+
def _calculate_bollinger_bands(self, prices: pd.Series, period: int = 20, std_dev: int = 2) -> Tuple[pd.Series, pd.Series]:
|
254 |
+
"""Calculate Bollinger Bands"""
|
255 |
+
sma = prices.rolling(window=period).mean()
|
256 |
+
std = prices.rolling(window=period).std()
|
257 |
+
upper_band = sma + (std * std_dev)
|
258 |
+
lower_band = sma - (std * std_dev)
|
259 |
+
return upper_band, lower_band
|
260 |
+
|
261 |
+
def _calculate_macd(self, prices: pd.Series, fast: int = 12, slow: int = 26, signal: int = 9) -> pd.Series:
|
262 |
+
"""Calculate MACD indicator"""
|
263 |
+
ema_fast = prices.ewm(span=fast).mean()
|
264 |
+
ema_slow = prices.ewm(span=slow).mean()
|
265 |
+
macd_line = ema_fast - ema_slow
|
266 |
+
return macd_line
|
267 |
+
|
268 |
+
def train(self, data: pd.DataFrame, total_timesteps: int = 100000,
|
269 |
+
eval_freq: int = 10000, eval_data: Optional[pd.DataFrame] = None) -> Dict[str, Any]:
|
270 |
+
"""Train the FinRL agent"""
|
271 |
+
|
272 |
+
logger.info("Starting FinRL agent training")
|
273 |
+
|
274 |
+
# Prepare data
|
275 |
+
train_data = self.prepare_data(data)
|
276 |
+
|
277 |
+
# Create training environment
|
278 |
+
self.env = DummyVecEnv([lambda: self.create_environment(train_data)])
|
279 |
+
|
280 |
+
# Create evaluation environment if provided
|
281 |
+
if eval_data is not None:
|
282 |
+
eval_data = self.prepare_data(eval_data)
|
283 |
+
self.eval_env = DummyVecEnv([lambda: self.create_environment(eval_data)])
|
284 |
+
self.callback = EvalCallback(
|
285 |
+
self.eval_env,
|
286 |
+
best_model_save_path="models/finrl_best/",
|
287 |
+
log_path="logs/finrl_eval/",
|
288 |
+
eval_freq=eval_freq,
|
289 |
+
deterministic=True,
|
290 |
+
render=False
|
291 |
+
)
|
292 |
+
|
293 |
+
# Initialize model based on algorithm
|
294 |
+
if self.config.algorithm == "PPO":
|
295 |
+
self.model = PPO(
|
296 |
+
"MlpPolicy",
|
297 |
+
self.env,
|
298 |
+
learning_rate=self.config.learning_rate,
|
299 |
+
batch_size=self.config.batch_size,
|
300 |
+
gamma=self.config.gamma,
|
301 |
+
verbose=self.config.verbose,
|
302 |
+
tensorboard_log=self.config.tensorboard_log
|
303 |
+
)
|
304 |
+
elif self.config.algorithm == "A2C":
|
305 |
+
self.model = A2C(
|
306 |
+
"MlpPolicy",
|
307 |
+
self.env,
|
308 |
+
learning_rate=self.config.learning_rate,
|
309 |
+
gamma=self.config.gamma,
|
310 |
+
verbose=self.config.verbose,
|
311 |
+
tensorboard_log=self.config.tensorboard_log
|
312 |
+
)
|
313 |
+
elif self.config.algorithm == "DDPG":
|
314 |
+
self.model = DDPG(
|
315 |
+
"MlpPolicy",
|
316 |
+
self.env,
|
317 |
+
learning_rate=self.config.learning_rate,
|
318 |
+
buffer_size=self.config.buffer_size,
|
319 |
+
learning_starts=self.config.learning_starts,
|
320 |
+
gamma=self.config.gamma,
|
321 |
+
tau=self.config.tau,
|
322 |
+
train_freq=self.config.train_freq,
|
323 |
+
gradient_steps=self.config.gradient_steps,
|
324 |
+
verbose=self.config.verbose,
|
325 |
+
tensorboard_log=self.config.tensorboard_log
|
326 |
+
)
|
327 |
+
elif self.config.algorithm == "TD3":
|
328 |
+
self.model = TD3(
|
329 |
+
"MlpPolicy",
|
330 |
+
self.env,
|
331 |
+
learning_rate=self.config.learning_rate,
|
332 |
+
buffer_size=self.config.buffer_size,
|
333 |
+
learning_starts=self.config.learning_starts,
|
334 |
+
gamma=self.config.gamma,
|
335 |
+
tau=self.config.tau,
|
336 |
+
train_freq=self.config.train_freq,
|
337 |
+
gradient_steps=self.config.gradient_steps,
|
338 |
+
target_update_interval=self.config.target_update_interval,
|
339 |
+
verbose=self.config.verbose,
|
340 |
+
tensorboard_log=self.config.tensorboard_log
|
341 |
+
)
|
342 |
+
else:
|
343 |
+
raise ValueError(f"Unsupported algorithm: {self.config.algorithm}")
|
344 |
+
|
345 |
+
# Train the model
|
346 |
+
callbacks = [self.callback] if self.callback else None
|
347 |
+
self.model.learn(
|
348 |
+
total_timesteps=total_timesteps,
|
349 |
+
callback=callbacks
|
350 |
+
)
|
351 |
+
|
352 |
+
logger.info("FinRL agent training completed")
|
353 |
+
|
354 |
+
return {
|
355 |
+
'algorithm': self.config.algorithm,
|
356 |
+
'total_timesteps': total_timesteps,
|
357 |
+
'model_path': f"models/finrl_{self.config.algorithm.lower()}"
|
358 |
+
}
|
359 |
+
|
360 |
+
def predict(self, data: pd.DataFrame) -> List[int]:
|
361 |
+
"""Generate trading predictions using the trained model"""
|
362 |
+
if self.model is None:
|
363 |
+
raise ValueError("Model not trained. Call train() first.")
|
364 |
+
|
365 |
+
# Prepare data
|
366 |
+
test_data = self.prepare_data(data)
|
367 |
+
|
368 |
+
# Create test environment
|
369 |
+
test_env = self.create_environment(test_data)
|
370 |
+
|
371 |
+
predictions = []
|
372 |
+
obs, _ = test_env.reset()
|
373 |
+
|
374 |
+
done = False
|
375 |
+
while not done:
|
376 |
+
action, _ = self.model.predict(obs, deterministic=True)
|
377 |
+
predictions.append(action)
|
378 |
+
obs, _, done, _, _ = test_env.step(action)
|
379 |
+
|
380 |
+
return predictions
|
381 |
+
|
382 |
+
def evaluate(self, data: pd.DataFrame) -> Dict[str, float]:
|
383 |
+
"""Evaluate the trained model on test data"""
|
384 |
+
if self.model is None:
|
385 |
+
raise ValueError("Model not trained. Call train() first.")
|
386 |
+
|
387 |
+
# Prepare data
|
388 |
+
test_data = self.prepare_data(data)
|
389 |
+
|
390 |
+
# Create test environment
|
391 |
+
test_env = self.create_environment(test_data)
|
392 |
+
|
393 |
+
obs, _ = test_env.reset()
|
394 |
+
done = False
|
395 |
+
total_reward = 0
|
396 |
+
steps = 0
|
397 |
+
|
398 |
+
while not done:
|
399 |
+
action, _ = self.model.predict(obs, deterministic=True)
|
400 |
+
obs, reward, done, _, info = test_env.step(action)
|
401 |
+
total_reward += reward
|
402 |
+
steps += 1
|
403 |
+
|
404 |
+
# Calculate metrics
|
405 |
+
final_portfolio_value = info['portfolio_value']
|
406 |
+
initial_balance = test_env.initial_balance
|
407 |
+
total_return = (final_portfolio_value - initial_balance) / initial_balance
|
408 |
+
|
409 |
+
return {
|
410 |
+
'total_reward': total_reward,
|
411 |
+
'total_return': total_return,
|
412 |
+
'final_portfolio_value': final_portfolio_value,
|
413 |
+
'steps': steps,
|
414 |
+
'sharpe_ratio': total_reward / steps if steps > 0 else 0
|
415 |
+
}
|
416 |
+
|
417 |
+
def save_model(self, path: str):
|
418 |
+
"""Save the trained model"""
|
419 |
+
if self.model is None:
|
420 |
+
raise ValueError("No model to save. Train the model first.")
|
421 |
+
|
422 |
+
self.model.save(path)
|
423 |
+
logger.info(f"Model saved to {path}")
|
424 |
+
|
425 |
+
def load_model(self, path: str):
|
426 |
+
"""Load a trained model"""
|
427 |
+
if self.config.algorithm == "PPO":
|
428 |
+
self.model = PPO.load(path)
|
429 |
+
elif self.config.algorithm == "A2C":
|
430 |
+
self.model = A2C.load(path)
|
431 |
+
elif self.config.algorithm == "DDPG":
|
432 |
+
self.model = DDPG.load(path)
|
433 |
+
elif self.config.algorithm == "TD3":
|
434 |
+
self.model = TD3.load(path)
|
435 |
+
else:
|
436 |
+
raise ValueError(f"Unsupported algorithm: {self.config.algorithm}")
|
437 |
+
|
438 |
+
logger.info(f"Model loaded from {path}")
|
439 |
+
|
440 |
+
|
441 |
+
def create_finrl_agent_from_config(config_path: str) -> FinRLAgent:
|
442 |
+
"""Create FinRL agent from configuration file"""
|
443 |
+
with open(config_path, 'r') as file:
|
444 |
+
config_data = yaml.safe_load(file)
|
445 |
+
|
446 |
+
finrl_config = FinRLConfig(**config_data.get('finrl', {}))
|
447 |
+
return FinRLAgent(finrl_config)
|
config.yaml
CHANGED
@@ -33,3 +33,30 @@ logging:
|
|
33 |
enable_file: true
|
34 |
max_file_size_mb: 10
|
35 |
backup_count: 5
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
enable_file: true
|
34 |
max_file_size_mb: 10
|
35 |
backup_count: 5
|
36 |
+
|
37 |
+
# FinRL configuration
|
38 |
+
finrl:
|
39 |
+
algorithm: 'PPO' # PPO, A2C, DDPG, TD3
|
40 |
+
learning_rate: 0.0003
|
41 |
+
batch_size: 64
|
42 |
+
buffer_size: 1000000
|
43 |
+
learning_starts: 100
|
44 |
+
gamma: 0.99
|
45 |
+
tau: 0.005
|
46 |
+
train_freq: 1
|
47 |
+
gradient_steps: 1
|
48 |
+
target_update_interval: 1
|
49 |
+
exploration_fraction: 0.1
|
50 |
+
exploration_initial_eps: 1.0
|
51 |
+
exploration_final_eps: 0.05
|
52 |
+
max_grad_norm: 10.0
|
53 |
+
verbose: 1
|
54 |
+
tensorboard_log: 'logs/finrl_tensorboard'
|
55 |
+
training:
|
56 |
+
total_timesteps: 100000
|
57 |
+
eval_freq: 10000
|
58 |
+
save_best_model: true
|
59 |
+
model_save_path: 'models/finrl_best/'
|
60 |
+
inference:
|
61 |
+
use_trained_model: false
|
62 |
+
model_path: 'models/finrl_best/best_model'
|
finrl_demo.py
ADDED
@@ -0,0 +1,294 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
"""
|
3 |
+
FinRL Demo Script
|
4 |
+
|
5 |
+
This script demonstrates the integration of FinRL with the algorithmic trading system.
|
6 |
+
It shows how to train a reinforcement learning agent and use it for trading decisions.
|
7 |
+
"""
|
8 |
+
|
9 |
+
import os
|
10 |
+
import sys
|
11 |
+
import yaml
|
12 |
+
import pandas as pd
|
13 |
+
import numpy as np
|
14 |
+
import matplotlib.pyplot as plt
|
15 |
+
import seaborn as sns
|
16 |
+
from datetime import datetime, timedelta
|
17 |
+
import logging
|
18 |
+
|
19 |
+
# Add the project root to the path
|
20 |
+
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
21 |
+
|
22 |
+
from agentic_ai_system.finrl_agent import FinRLAgent, FinRLConfig, create_finrl_agent_from_config
|
23 |
+
from agentic_ai_system.synthetic_data_generator import SyntheticDataGenerator
|
24 |
+
from agentic_ai_system.logger_config import setup_logging
|
25 |
+
|
26 |
+
# Setup logging
|
27 |
+
setup_logging()
|
28 |
+
logger = logging.getLogger(__name__)
|
29 |
+
|
30 |
+
|
31 |
+
def load_config(config_path: str = 'config.yaml') -> dict:
|
32 |
+
"""Load configuration from YAML file"""
|
33 |
+
with open(config_path, 'r') as file:
|
34 |
+
return yaml.safe_load(file)
|
35 |
+
|
36 |
+
|
37 |
+
def generate_training_data(config: dict) -> pd.DataFrame:
|
38 |
+
"""Generate synthetic data for training"""
|
39 |
+
logger.info("Generating synthetic training data")
|
40 |
+
|
41 |
+
generator = SyntheticDataGenerator(config)
|
42 |
+
|
43 |
+
# Generate training data (longer period)
|
44 |
+
train_data = generator.generate_ohlcv_data(
|
45 |
+
symbol='AAPL',
|
46 |
+
start_date='2023-01-01',
|
47 |
+
end_date='2023-12-31',
|
48 |
+
frequency='1H'
|
49 |
+
)
|
50 |
+
|
51 |
+
# Add technical indicators
|
52 |
+
train_data['sma_20'] = train_data['close'].rolling(window=20).mean()
|
53 |
+
train_data['sma_50'] = train_data['close'].rolling(window=50).mean()
|
54 |
+
train_data['rsi'] = calculate_rsi(train_data['close'])
|
55 |
+
bb_upper, bb_lower = calculate_bollinger_bands(train_data['close'])
|
56 |
+
train_data['bb_upper'] = bb_upper
|
57 |
+
train_data['bb_lower'] = bb_lower
|
58 |
+
train_data['macd'] = calculate_macd(train_data['close'])
|
59 |
+
|
60 |
+
# Fill NaN values
|
61 |
+
train_data = train_data.fillna(method='bfill').fillna(0)
|
62 |
+
|
63 |
+
logger.info(f"Generated {len(train_data)} training samples")
|
64 |
+
return train_data
|
65 |
+
|
66 |
+
|
67 |
+
def generate_test_data(config: dict) -> pd.DataFrame:
|
68 |
+
"""Generate synthetic data for testing"""
|
69 |
+
logger.info("Generating synthetic test data")
|
70 |
+
|
71 |
+
generator = SyntheticDataGenerator(config)
|
72 |
+
|
73 |
+
# Generate test data (shorter period)
|
74 |
+
test_data = generator.generate_ohlcv_data(
|
75 |
+
symbol='AAPL',
|
76 |
+
start_date='2024-01-01',
|
77 |
+
end_date='2024-03-31',
|
78 |
+
frequency='1H'
|
79 |
+
)
|
80 |
+
|
81 |
+
# Add technical indicators
|
82 |
+
test_data['sma_20'] = test_data['close'].rolling(window=20).mean()
|
83 |
+
test_data['sma_50'] = test_data['close'].rolling(window=50).mean()
|
84 |
+
test_data['rsi'] = calculate_rsi(test_data['close'])
|
85 |
+
bb_upper, bb_lower = calculate_bollinger_bands(test_data['close'])
|
86 |
+
test_data['bb_upper'] = bb_upper
|
87 |
+
test_data['bb_lower'] = bb_lower
|
88 |
+
test_data['macd'] = calculate_macd(test_data['close'])
|
89 |
+
|
90 |
+
# Fill NaN values
|
91 |
+
test_data = test_data.fillna(method='bfill').fillna(0)
|
92 |
+
|
93 |
+
logger.info(f"Generated {len(test_data)} test samples")
|
94 |
+
return test_data
|
95 |
+
|
96 |
+
|
97 |
+
def calculate_rsi(prices: pd.Series, period: int = 14) -> pd.Series:
|
98 |
+
"""Calculate RSI indicator"""
|
99 |
+
delta = prices.diff()
|
100 |
+
gain = (delta.where(delta > 0, 0)).rolling(window=period).mean()
|
101 |
+
loss = (-delta.where(delta < 0, 0)).rolling(window=period).mean()
|
102 |
+
rs = gain / loss
|
103 |
+
rsi = 100 - (100 / (1 + rs))
|
104 |
+
return rsi
|
105 |
+
|
106 |
+
|
107 |
+
def calculate_bollinger_bands(prices: pd.Series, period: int = 20, std_dev: int = 2):
|
108 |
+
"""Calculate Bollinger Bands"""
|
109 |
+
sma = prices.rolling(window=period).mean()
|
110 |
+
std = prices.rolling(window=period).std()
|
111 |
+
upper_band = sma + (std * std_dev)
|
112 |
+
lower_band = sma - (std * std_dev)
|
113 |
+
return upper_band, lower_band
|
114 |
+
|
115 |
+
|
116 |
+
def calculate_macd(prices: pd.Series, fast: int = 12, slow: int = 26, signal: int = 9) -> pd.Series:
|
117 |
+
"""Calculate MACD indicator"""
|
118 |
+
ema_fast = prices.ewm(span=fast).mean()
|
119 |
+
ema_slow = prices.ewm(span=slow).mean()
|
120 |
+
macd_line = ema_fast - ema_slow
|
121 |
+
return macd_line
|
122 |
+
|
123 |
+
|
124 |
+
def train_finrl_agent(config: dict, train_data: pd.DataFrame, test_data: pd.DataFrame) -> FinRLAgent:
|
125 |
+
"""Train the FinRL agent"""
|
126 |
+
logger.info("Starting FinRL agent training")
|
127 |
+
|
128 |
+
# Create FinRL agent
|
129 |
+
finrl_config = FinRLConfig(**config['finrl'])
|
130 |
+
agent = FinRLAgent(finrl_config)
|
131 |
+
|
132 |
+
# Train the agent
|
133 |
+
training_result = agent.train(
|
134 |
+
data=train_data,
|
135 |
+
total_timesteps=config['finrl']['training']['total_timesteps'],
|
136 |
+
eval_freq=config['finrl']['training']['eval_freq'],
|
137 |
+
eval_data=test_data
|
138 |
+
)
|
139 |
+
|
140 |
+
logger.info(f"Training completed: {training_result}")
|
141 |
+
|
142 |
+
# Save the model
|
143 |
+
if config['finrl']['training']['save_best_model']:
|
144 |
+
model_path = config['finrl']['training']['model_save_path']
|
145 |
+
os.makedirs(os.path.dirname(model_path), exist_ok=True)
|
146 |
+
agent.save_model(model_path)
|
147 |
+
|
148 |
+
return agent
|
149 |
+
|
150 |
+
|
151 |
+
def evaluate_agent(agent: FinRLAgent, test_data: pd.DataFrame) -> dict:
|
152 |
+
"""Evaluate the trained agent"""
|
153 |
+
logger.info("Evaluating FinRL agent")
|
154 |
+
|
155 |
+
# Evaluate on test data
|
156 |
+
evaluation_results = agent.evaluate(test_data)
|
157 |
+
|
158 |
+
logger.info(f"Evaluation results: {evaluation_results}")
|
159 |
+
|
160 |
+
return evaluation_results
|
161 |
+
|
162 |
+
|
163 |
+
def generate_predictions(agent: FinRLAgent, test_data: pd.DataFrame) -> list:
|
164 |
+
"""Generate trading predictions"""
|
165 |
+
logger.info("Generating trading predictions")
|
166 |
+
|
167 |
+
predictions = agent.predict(test_data)
|
168 |
+
|
169 |
+
logger.info(f"Generated {len(predictions)} predictions")
|
170 |
+
|
171 |
+
return predictions
|
172 |
+
|
173 |
+
|
174 |
+
def plot_results(test_data: pd.DataFrame, predictions: list, evaluation_results: dict):
|
175 |
+
"""Plot trading results"""
|
176 |
+
logger.info("Creating visualization plots")
|
177 |
+
|
178 |
+
# Create figure with subplots
|
179 |
+
fig, axes = plt.subplots(3, 1, figsize=(15, 12))
|
180 |
+
|
181 |
+
# Plot 1: Price and predictions
|
182 |
+
axes[0].plot(test_data.index, test_data['close'], label='Close Price', alpha=0.7)
|
183 |
+
|
184 |
+
# Mark buy/sell signals
|
185 |
+
buy_signals = [i for i, pred in enumerate(predictions) if pred == 2]
|
186 |
+
sell_signals = [i for i, pred in enumerate(predictions) if pred == 0]
|
187 |
+
|
188 |
+
if buy_signals:
|
189 |
+
axes[0].scatter(test_data.index[buy_signals], test_data['close'].iloc[buy_signals],
|
190 |
+
color='green', marker='^', s=100, label='Buy Signal', alpha=0.8)
|
191 |
+
if sell_signals:
|
192 |
+
axes[0].scatter(test_data.index[sell_signals], test_data['close'].iloc[sell_signals],
|
193 |
+
color='red', marker='v', s=100, label='Sell Signal', alpha=0.8)
|
194 |
+
|
195 |
+
axes[0].set_title('Price Action and Trading Signals')
|
196 |
+
axes[0].set_ylabel('Price')
|
197 |
+
axes[0].legend()
|
198 |
+
axes[0].grid(True, alpha=0.3)
|
199 |
+
|
200 |
+
# Plot 2: Technical indicators
|
201 |
+
axes[1].plot(test_data.index, test_data['close'], label='Close Price', alpha=0.7)
|
202 |
+
axes[1].plot(test_data.index, test_data['sma_20'], label='SMA 20', alpha=0.7)
|
203 |
+
axes[1].plot(test_data.index, test_data['sma_50'], label='SMA 50', alpha=0.7)
|
204 |
+
axes[1].plot(test_data.index, test_data['bb_upper'], label='BB Upper', alpha=0.5)
|
205 |
+
axes[1].plot(test_data.index, test_data['bb_lower'], label='BB Lower', alpha=0.5)
|
206 |
+
|
207 |
+
axes[1].set_title('Technical Indicators')
|
208 |
+
axes[1].set_ylabel('Price')
|
209 |
+
axes[1].legend()
|
210 |
+
axes[1].grid(True, alpha=0.3)
|
211 |
+
|
212 |
+
# Plot 3: RSI
|
213 |
+
axes[2].plot(test_data.index, test_data['rsi'], label='RSI', color='purple')
|
214 |
+
axes[2].axhline(y=70, color='r', linestyle='--', alpha=0.5, label='Overbought')
|
215 |
+
axes[2].axhline(y=30, color='g', linestyle='--', alpha=0.5, label='Oversold')
|
216 |
+
axes[2].set_title('RSI Indicator')
|
217 |
+
axes[2].set_ylabel('RSI')
|
218 |
+
axes[2].set_xlabel('Time')
|
219 |
+
axes[2].legend()
|
220 |
+
axes[2].grid(True, alpha=0.3)
|
221 |
+
|
222 |
+
plt.tight_layout()
|
223 |
+
|
224 |
+
# Save plot
|
225 |
+
os.makedirs('plots', exist_ok=True)
|
226 |
+
plt.savefig('plots/finrl_trading_results.png', dpi=300, bbox_inches='tight')
|
227 |
+
plt.show()
|
228 |
+
|
229 |
+
logger.info("Plots saved to plots/finrl_trading_results.png")
|
230 |
+
|
231 |
+
|
232 |
+
def print_summary(evaluation_results: dict, predictions: list):
|
233 |
+
"""Print trading summary"""
|
234 |
+
print("\n" + "="*60)
|
235 |
+
print("FINRL TRADING SYSTEM SUMMARY")
|
236 |
+
print("="*60)
|
237 |
+
|
238 |
+
print(f"Algorithm: {evaluation_results.get('algorithm', 'Unknown')}")
|
239 |
+
print(f"Total Return: {evaluation_results['total_return']:.2%}")
|
240 |
+
print(f"Final Portfolio Value: ${evaluation_results['final_portfolio_value']:,.2f}")
|
241 |
+
print(f"Total Reward: {evaluation_results['total_reward']:.4f}")
|
242 |
+
print(f"Sharpe Ratio: {evaluation_results['sharpe_ratio']:.4f}")
|
243 |
+
print(f"Number of Trading Steps: {evaluation_results['steps']}")
|
244 |
+
|
245 |
+
# Trading statistics
|
246 |
+
buy_signals = sum(1 for pred in predictions if pred == 2)
|
247 |
+
sell_signals = sum(1 for pred in predictions if pred == 0)
|
248 |
+
hold_signals = sum(1 for pred in predictions if pred == 1)
|
249 |
+
|
250 |
+
print(f"\nTrading Signals:")
|
251 |
+
print(f" Buy signals: {buy_signals}")
|
252 |
+
print(f" Sell signals: {sell_signals}")
|
253 |
+
print(f" Hold signals: {hold_signals}")
|
254 |
+
print(f" Total signals: {len(predictions)}")
|
255 |
+
|
256 |
+
print("\n" + "="*60)
|
257 |
+
|
258 |
+
|
259 |
+
def main():
|
260 |
+
"""Main function to run the FinRL demo"""
|
261 |
+
logger.info("Starting FinRL Demo")
|
262 |
+
|
263 |
+
try:
|
264 |
+
# Load configuration
|
265 |
+
config = load_config()
|
266 |
+
|
267 |
+
# Generate data
|
268 |
+
train_data = generate_training_data(config)
|
269 |
+
test_data = generate_test_data(config)
|
270 |
+
|
271 |
+
# Train FinRL agent
|
272 |
+
agent = train_finrl_agent(config, train_data, test_data)
|
273 |
+
|
274 |
+
# Evaluate agent
|
275 |
+
evaluation_results = evaluate_agent(agent, test_data)
|
276 |
+
|
277 |
+
# Generate predictions
|
278 |
+
predictions = generate_predictions(agent, test_data)
|
279 |
+
|
280 |
+
# Create visualizations
|
281 |
+
plot_results(test_data, predictions, evaluation_results)
|
282 |
+
|
283 |
+
# Print summary
|
284 |
+
print_summary(evaluation_results, predictions)
|
285 |
+
|
286 |
+
logger.info("FinRL Demo completed successfully")
|
287 |
+
|
288 |
+
except Exception as e:
|
289 |
+
logger.error(f"Error in FinRL demo: {str(e)}")
|
290 |
+
raise
|
291 |
+
|
292 |
+
|
293 |
+
if __name__ == "__main__":
|
294 |
+
main()
|
requirements.txt
CHANGED
@@ -7,3 +7,8 @@ pytest
|
|
7 |
pytest-cov
|
8 |
python-dateutil
|
9 |
scipy
|
|
|
|
|
|
|
|
|
|
|
|
7 |
pytest-cov
|
8 |
python-dateutil
|
9 |
scipy
|
10 |
+
finrl
|
11 |
+
stable-baselines3
|
12 |
+
gymnasium
|
13 |
+
tensorboard
|
14 |
+
torch
|
tests/test_finrl_agent.py
ADDED
@@ -0,0 +1,373 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Tests for FinRL Agent
|
3 |
+
|
4 |
+
This module contains comprehensive tests for the FinRL agent functionality.
|
5 |
+
"""
|
6 |
+
|
7 |
+
import pytest
|
8 |
+
import pandas as pd
|
9 |
+
import numpy as np
|
10 |
+
import yaml
|
11 |
+
import tempfile
|
12 |
+
import os
|
13 |
+
from unittest.mock import Mock, patch
|
14 |
+
|
15 |
+
# Add the project root to the path
|
16 |
+
import sys
|
17 |
+
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
18 |
+
|
19 |
+
from agentic_ai_system.finrl_agent import (
|
20 |
+
FinRLAgent,
|
21 |
+
FinRLConfig,
|
22 |
+
TradingEnvironment,
|
23 |
+
create_finrl_agent_from_config
|
24 |
+
)
|
25 |
+
|
26 |
+
|
27 |
+
class TestFinRLConfig:
|
28 |
+
"""Test FinRL configuration"""
|
29 |
+
|
30 |
+
def test_default_config(self):
|
31 |
+
"""Test default configuration values"""
|
32 |
+
config = FinRLConfig()
|
33 |
+
|
34 |
+
assert config.algorithm == "PPO"
|
35 |
+
assert config.learning_rate == 0.0003
|
36 |
+
assert config.batch_size == 64
|
37 |
+
assert config.gamma == 0.99
|
38 |
+
|
39 |
+
def test_custom_config(self):
|
40 |
+
"""Test custom configuration values"""
|
41 |
+
config = FinRLConfig(
|
42 |
+
algorithm="A2C",
|
43 |
+
learning_rate=0.001,
|
44 |
+
batch_size=128
|
45 |
+
)
|
46 |
+
|
47 |
+
assert config.algorithm == "A2C"
|
48 |
+
assert config.learning_rate == 0.001
|
49 |
+
assert config.batch_size == 128
|
50 |
+
|
51 |
+
|
52 |
+
class TestTradingEnvironment:
|
53 |
+
"""Test trading environment"""
|
54 |
+
|
55 |
+
@pytest.fixture
|
56 |
+
def sample_data(self):
|
57 |
+
"""Create sample market data"""
|
58 |
+
dates = pd.date_range('2024-01-01', periods=100, freq='1H')
|
59 |
+
data = pd.DataFrame({
|
60 |
+
'open': np.random.uniform(100, 200, 100),
|
61 |
+
'high': np.random.uniform(100, 200, 100),
|
62 |
+
'low': np.random.uniform(100, 200, 100),
|
63 |
+
'close': np.random.uniform(100, 200, 100),
|
64 |
+
'volume': np.random.uniform(1000, 10000, 100),
|
65 |
+
'sma_20': np.random.uniform(100, 200, 100),
|
66 |
+
'sma_50': np.random.uniform(100, 200, 100),
|
67 |
+
'rsi': np.random.uniform(0, 100, 100),
|
68 |
+
'bb_upper': np.random.uniform(100, 200, 100),
|
69 |
+
'bb_lower': np.random.uniform(100, 200, 100),
|
70 |
+
'macd': np.random.uniform(-10, 10, 100)
|
71 |
+
}, index=dates)
|
72 |
+
return data
|
73 |
+
|
74 |
+
def test_environment_initialization(self, sample_data):
|
75 |
+
"""Test environment initialization"""
|
76 |
+
env = TradingEnvironment(sample_data)
|
77 |
+
|
78 |
+
assert env.initial_balance == 100000
|
79 |
+
assert env.transaction_fee == 0.001
|
80 |
+
assert env.max_position == 100
|
81 |
+
assert env.action_space.n == 3
|
82 |
+
assert len(env.observation_space.shape) == 1
|
83 |
+
|
84 |
+
def test_environment_reset(self, sample_data):
|
85 |
+
"""Test environment reset"""
|
86 |
+
env = TradingEnvironment(sample_data)
|
87 |
+
obs, info = env.reset()
|
88 |
+
|
89 |
+
assert env.current_step == 0
|
90 |
+
assert env.balance == env.initial_balance
|
91 |
+
assert env.position == 0
|
92 |
+
assert env.portfolio_value == env.initial_balance
|
93 |
+
assert isinstance(obs, np.ndarray)
|
94 |
+
assert isinstance(info, dict)
|
95 |
+
|
96 |
+
def test_environment_step(self, sample_data):
|
97 |
+
"""Test environment step"""
|
98 |
+
env = TradingEnvironment(sample_data)
|
99 |
+
obs, info = env.reset()
|
100 |
+
|
101 |
+
# Test hold action
|
102 |
+
obs, reward, done, truncated, info = env.step(1)
|
103 |
+
|
104 |
+
assert isinstance(obs, np.ndarray)
|
105 |
+
assert isinstance(reward, float)
|
106 |
+
assert isinstance(done, bool)
|
107 |
+
assert isinstance(truncated, bool)
|
108 |
+
assert isinstance(info, dict)
|
109 |
+
assert env.current_step == 1
|
110 |
+
|
111 |
+
def test_buy_action(self, sample_data):
|
112 |
+
"""Test buy action"""
|
113 |
+
env = TradingEnvironment(sample_data, initial_balance=10000)
|
114 |
+
obs, info = env.reset()
|
115 |
+
|
116 |
+
initial_balance = env.balance
|
117 |
+
initial_position = env.position
|
118 |
+
|
119 |
+
# Buy action
|
120 |
+
obs, reward, done, truncated, info = env.step(2)
|
121 |
+
|
122 |
+
assert env.position > initial_position
|
123 |
+
assert env.balance < initial_balance
|
124 |
+
|
125 |
+
def test_sell_action(self, sample_data):
|
126 |
+
"""Test sell action"""
|
127 |
+
env = TradingEnvironment(sample_data, initial_balance=10000)
|
128 |
+
obs, info = env.reset()
|
129 |
+
|
130 |
+
# First buy some shares
|
131 |
+
obs, reward, done, truncated, info = env.step(2)
|
132 |
+
initial_position = env.position
|
133 |
+
initial_balance = env.balance
|
134 |
+
|
135 |
+
# Then sell
|
136 |
+
obs, reward, done, truncated, info = env.step(0)
|
137 |
+
|
138 |
+
assert env.position < initial_position
|
139 |
+
assert env.balance > initial_balance
|
140 |
+
|
141 |
+
def test_portfolio_value_calculation(self, sample_data):
|
142 |
+
"""Test portfolio value calculation"""
|
143 |
+
env = TradingEnvironment(sample_data)
|
144 |
+
obs, info = env.reset()
|
145 |
+
|
146 |
+
# Buy some shares
|
147 |
+
obs, reward, done, truncated, info = env.step(2)
|
148 |
+
|
149 |
+
expected_value = env.balance + (env.position * sample_data.iloc[env.current_step]['close'])
|
150 |
+
assert abs(env.portfolio_value - expected_value) < 1e-6
|
151 |
+
|
152 |
+
|
153 |
+
class TestFinRLAgent:
|
154 |
+
"""Test FinRL agent"""
|
155 |
+
|
156 |
+
@pytest.fixture
|
157 |
+
def sample_data(self):
|
158 |
+
"""Create sample market data"""
|
159 |
+
dates = pd.date_range('2024-01-01', periods=100, freq='1H')
|
160 |
+
data = pd.DataFrame({
|
161 |
+
'open': np.random.uniform(100, 200, 100),
|
162 |
+
'high': np.random.uniform(100, 200, 100),
|
163 |
+
'low': np.random.uniform(100, 200, 100),
|
164 |
+
'close': np.random.uniform(100, 200, 100),
|
165 |
+
'volume': np.random.uniform(1000, 10000, 100)
|
166 |
+
}, index=dates)
|
167 |
+
return data
|
168 |
+
|
169 |
+
@pytest.fixture
|
170 |
+
def finrl_config(self):
|
171 |
+
"""Create FinRL configuration"""
|
172 |
+
return FinRLConfig(
|
173 |
+
algorithm="PPO",
|
174 |
+
learning_rate=0.0003,
|
175 |
+
batch_size=32,
|
176 |
+
total_timesteps=1000
|
177 |
+
)
|
178 |
+
|
179 |
+
def test_agent_initialization(self, finrl_config):
|
180 |
+
"""Test agent initialization"""
|
181 |
+
agent = FinRLAgent(finrl_config)
|
182 |
+
|
183 |
+
assert agent.config == finrl_config
|
184 |
+
assert agent.model is None
|
185 |
+
assert agent.env is None
|
186 |
+
|
187 |
+
def test_prepare_data(self, finrl_config, sample_data):
|
188 |
+
"""Test data preparation"""
|
189 |
+
agent = FinRLAgent(finrl_config)
|
190 |
+
prepared_data = agent.prepare_data(sample_data)
|
191 |
+
|
192 |
+
# Check that technical indicators were added
|
193 |
+
assert 'sma_20' in prepared_data.columns
|
194 |
+
assert 'sma_50' in prepared_data.columns
|
195 |
+
assert 'rsi' in prepared_data.columns
|
196 |
+
assert 'bb_upper' in prepared_data.columns
|
197 |
+
assert 'bb_lower' in prepared_data.columns
|
198 |
+
assert 'macd' in prepared_data.columns
|
199 |
+
|
200 |
+
# Check that no NaN values remain
|
201 |
+
assert not prepared_data.isnull().any().any()
|
202 |
+
|
203 |
+
def test_create_environment(self, finrl_config, sample_data):
|
204 |
+
"""Test environment creation"""
|
205 |
+
agent = FinRLAgent(finrl_config)
|
206 |
+
env = agent.create_environment(sample_data)
|
207 |
+
|
208 |
+
assert isinstance(env, TradingEnvironment)
|
209 |
+
assert env.data.equals(sample_data)
|
210 |
+
|
211 |
+
def test_technical_indicators_calculation(self, finrl_config):
|
212 |
+
"""Test technical indicators calculation"""
|
213 |
+
agent = FinRLAgent(finrl_config)
|
214 |
+
|
215 |
+
# Test RSI calculation
|
216 |
+
prices = pd.Series([100, 101, 99, 102, 98, 103, 97, 104, 96, 105])
|
217 |
+
rsi = agent._calculate_rsi(prices, period=3)
|
218 |
+
assert len(rsi) == len(prices)
|
219 |
+
assert not rsi.isnull().all()
|
220 |
+
|
221 |
+
# Test Bollinger Bands calculation
|
222 |
+
bb_upper, bb_lower = agent._calculate_bollinger_bands(prices, period=3)
|
223 |
+
assert len(bb_upper) == len(prices)
|
224 |
+
assert len(bb_lower) == len(prices)
|
225 |
+
assert (bb_upper >= bb_lower).all()
|
226 |
+
|
227 |
+
# Test MACD calculation
|
228 |
+
macd = agent._calculate_macd(prices)
|
229 |
+
assert len(macd) == len(prices)
|
230 |
+
|
231 |
+
@patch('agentic_ai_system.finrl_agent.PPO')
|
232 |
+
def test_training_ppo(self, mock_ppo, finrl_config, sample_data):
|
233 |
+
"""Test PPO training"""
|
234 |
+
# Mock the PPO model
|
235 |
+
mock_model = Mock()
|
236 |
+
mock_ppo.return_value = mock_model
|
237 |
+
|
238 |
+
agent = FinRLAgent(finrl_config)
|
239 |
+
result = agent.train(sample_data, total_timesteps=100)
|
240 |
+
|
241 |
+
assert result['algorithm'] == 'PPO'
|
242 |
+
assert result['total_timesteps'] == 100
|
243 |
+
mock_model.learn.assert_called_once()
|
244 |
+
|
245 |
+
@patch('agentic_ai_system.finrl_agent.A2C')
|
246 |
+
def test_training_a2c(self, mock_a2c):
|
247 |
+
"""Test A2C training"""
|
248 |
+
config = FinRLConfig(algorithm="A2C")
|
249 |
+
mock_model = Mock()
|
250 |
+
mock_a2c.return_value = mock_model
|
251 |
+
|
252 |
+
agent = FinRLAgent(config)
|
253 |
+
sample_data = pd.DataFrame({
|
254 |
+
'open': [100, 101, 102],
|
255 |
+
'high': [101, 102, 103],
|
256 |
+
'low': [99, 100, 101],
|
257 |
+
'close': [100, 101, 102],
|
258 |
+
'volume': [1000, 1100, 1200]
|
259 |
+
})
|
260 |
+
|
261 |
+
result = agent.train(sample_data, total_timesteps=100)
|
262 |
+
|
263 |
+
assert result['algorithm'] == 'A2C'
|
264 |
+
mock_model.learn.assert_called_once()
|
265 |
+
|
266 |
+
def test_invalid_algorithm(self):
|
267 |
+
"""Test invalid algorithm handling"""
|
268 |
+
config = FinRLConfig(algorithm="INVALID")
|
269 |
+
agent = FinRLAgent(config)
|
270 |
+
sample_data = pd.DataFrame({
|
271 |
+
'open': [100, 101, 102],
|
272 |
+
'high': [101, 102, 103],
|
273 |
+
'low': [99, 100, 101],
|
274 |
+
'close': [100, 101, 102],
|
275 |
+
'volume': [1000, 1100, 1200]
|
276 |
+
})
|
277 |
+
|
278 |
+
with pytest.raises(ValueError, match="Unsupported algorithm"):
|
279 |
+
agent.train(sample_data, total_timesteps=100)
|
280 |
+
|
281 |
+
def test_predict_without_training(self, finrl_config, sample_data):
|
282 |
+
"""Test prediction without training"""
|
283 |
+
agent = FinRLAgent(finrl_config)
|
284 |
+
|
285 |
+
with pytest.raises(ValueError, match="Model not trained"):
|
286 |
+
agent.predict(sample_data)
|
287 |
+
|
288 |
+
def test_evaluate_without_training(self, finrl_config, sample_data):
|
289 |
+
"""Test evaluation without training"""
|
290 |
+
agent = FinRLAgent(finrl_config)
|
291 |
+
|
292 |
+
with pytest.raises(ValueError, match="Model not trained"):
|
293 |
+
agent.evaluate(sample_data)
|
294 |
+
|
295 |
+
@patch('agentic_ai_system.finrl_agent.PPO')
|
296 |
+
def test_save_and_load_model(self, mock_ppo, finrl_config, sample_data):
|
297 |
+
"""Test model saving and loading"""
|
298 |
+
# Mock the PPO model
|
299 |
+
mock_model = Mock()
|
300 |
+
mock_ppo.return_value = mock_model
|
301 |
+
mock_ppo.load.return_value = mock_model
|
302 |
+
|
303 |
+
agent = FinRLAgent(finrl_config)
|
304 |
+
|
305 |
+
# Train the agent
|
306 |
+
agent.train(sample_data, total_timesteps=100)
|
307 |
+
|
308 |
+
# Test saving
|
309 |
+
with tempfile.NamedTemporaryFile(suffix='.zip', delete=False) as tmp_file:
|
310 |
+
agent.save_model(tmp_file.name)
|
311 |
+
mock_model.save.assert_called_once_with(tmp_file.name)
|
312 |
+
|
313 |
+
# Test loading
|
314 |
+
agent.load_model(tmp_file.name)
|
315 |
+
mock_ppo.load.assert_called_once_with(tmp_file.name)
|
316 |
+
|
317 |
+
# Clean up
|
318 |
+
os.unlink(tmp_file.name)
|
319 |
+
|
320 |
+
|
321 |
+
class TestFinRLIntegration:
|
322 |
+
"""Test FinRL integration with configuration"""
|
323 |
+
|
324 |
+
def test_create_agent_from_config(self):
|
325 |
+
"""Test creating agent from configuration file"""
|
326 |
+
config_data = {
|
327 |
+
'finrl': {
|
328 |
+
'algorithm': 'PPO',
|
329 |
+
'learning_rate': 0.001,
|
330 |
+
'batch_size': 128,
|
331 |
+
'gamma': 0.95
|
332 |
+
}
|
333 |
+
}
|
334 |
+
|
335 |
+
with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml', delete=False) as tmp_file:
|
336 |
+
yaml.dump(config_data, tmp_file)
|
337 |
+
tmp_file_path = tmp_file.name
|
338 |
+
|
339 |
+
try:
|
340 |
+
agent = create_finrl_agent_from_config(tmp_file_path)
|
341 |
+
|
342 |
+
assert agent.config.algorithm == 'PPO'
|
343 |
+
assert agent.config.learning_rate == 0.001
|
344 |
+
assert agent.config.batch_size == 128
|
345 |
+
assert agent.config.gamma == 0.95
|
346 |
+
finally:
|
347 |
+
os.unlink(tmp_file_path)
|
348 |
+
|
349 |
+
def test_create_agent_from_config_missing_finrl(self):
|
350 |
+
"""Test creating agent from config without finrl section"""
|
351 |
+
config_data = {
|
352 |
+
'trading': {
|
353 |
+
'symbol': 'AAPL',
|
354 |
+
'capital': 100000
|
355 |
+
}
|
356 |
+
}
|
357 |
+
|
358 |
+
with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml', delete=False) as tmp_file:
|
359 |
+
yaml.dump(config_data, tmp_file)
|
360 |
+
tmp_file_path = tmp_file.name
|
361 |
+
|
362 |
+
try:
|
363 |
+
agent = create_finrl_agent_from_config(tmp_file_path)
|
364 |
+
|
365 |
+
# Should use default values
|
366 |
+
assert agent.config.algorithm == 'PPO'
|
367 |
+
assert agent.config.learning_rate == 0.0003
|
368 |
+
finally:
|
369 |
+
os.unlink(tmp_file_path)
|
370 |
+
|
371 |
+
|
372 |
+
if __name__ == "__main__":
|
373 |
+
pytest.main([__file__])
|