Edwin Salguero committed on
Commit
2c67d05
·
1 Parent(s): 859af74

Add FinRL integration with comprehensive RL trading agent


- Add FinRL agent with support for PPO, A2C, DDPG, and TD3 algorithms
- Create custom trading environment compatible with Gymnasium
- Implement technical indicators integration (RSI, Bollinger Bands, MACD)
- Add comprehensive configuration system for FinRL parameters
- Create demo script with training, evaluation, and visualization
- Add comprehensive test suite for FinRL functionality
- Update requirements.txt with FinRL dependencies
- Update README with detailed FinRL documentation
- Create necessary directories for models, logs, and plots

README.md CHANGED
@@ -1,6 +1,6 @@
1
  # Algorithmic Trading System
2
 
3
- A comprehensive algorithmic trading system with synthetic data generation, comprehensive logging, and extensive testing capabilities.
4
 
5
  ## Features
6
 
@@ -10,6 +10,15 @@ A comprehensive algorithmic trading system with synthetic data generation, compr
10
  - **Risk Management**: Position sizing and drawdown limits
11
  - **Order Execution**: Simulated broker integration with realistic execution delays
12
 
 
13
  ### Synthetic Data Generation
14
  - **Realistic Market Data**: Generate OHLCV data using geometric Brownian motion
15
  - **Multiple Frequencies**: Support for 1min, 5min, 1H, and 1D data
@@ -226,6 +235,129 @@ logger.warning("High volatility detected")
226
  logger.error("Order execution failed", exc_info=True)
227
  ```
228
 
229
  ## Testing
230
 
231
  ### Test Structure
 
1
  # Algorithmic Trading System
2
 
3
+ A comprehensive algorithmic trading system with synthetic data generation, extensive logging and testing capabilities, and FinRL reinforcement learning integration.
4
 
5
  ## Features
6
 
 
10
  - **Risk Management**: Position sizing and drawdown limits
11
  - **Order Execution**: Simulated broker integration with realistic execution delays
12
 
13
+ ### FinRL Reinforcement Learning
14
+ - **Multiple RL Algorithms**: Support for PPO, A2C, DDPG, and TD3
15
+ - **Custom Trading Environment**: Gymnasium-compatible environment for RL training
16
+ - **Technical Indicators Integration**: Automatic calculation and inclusion of technical indicators
17
+ - **Portfolio Management**: Realistic portfolio simulation with transaction costs
18
+ - **Model Persistence**: Save and load trained models for inference
19
+ - **TensorBoard Integration**: Training progress visualization and monitoring
20
+ - **Comprehensive Evaluation**: Performance metrics including Sharpe ratio and total returns
21
+
22
  ### Synthetic Data Generation
23
  - **Realistic Market Data**: Generate OHLCV data using geometric Brownian motion
24
  - **Multiple Frequencies**: Support for 1min, 5min, 1H, and 1D data
 
235
  logger.error("Order execution failed", exc_info=True)
236
  ```
237
 
238
+ ## FinRL Integration
239
+
240
+ ### Overview
241
+ The system now includes FinRL (Financial Reinforcement Learning) integration, providing state-of-the-art reinforcement learning capabilities for algorithmic trading. The FinRL agent can learn optimal trading strategies through interaction with a simulated market environment.
242
+
243
+ ### Supported Algorithms
244
+ - **PPO (Proximal Policy Optimization)**: Stable policy gradient method
245
+ - **A2C (Advantage Actor-Critic)**: Actor-critic method with advantage estimation
246
+ - **DDPG (Deep Deterministic Policy Gradient)**: Continuous action space algorithm
247
+ - **TD3 (Twin Delayed DDPG)**: Improved version of DDPG with twin critics (see the action-space note below)
248
+
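+ Note that in Stable-Baselines3, DDPG and TD3 expect a continuous (Box) action space, while the bundled `TradingEnvironment` uses a discrete Buy/Hold/Sell action space, so PPO and A2C are the natural fits out of the box. Switching algorithms is a one-line change on `FinRLConfig`; a minimal sketch (the extra DDPG/TD3 fields shown are those defined on the dataclass):
+
+ ```python
+ from agentic_ai_system.finrl_agent import FinRLAgent, FinRLConfig
+
+ # Policy-gradient methods work directly with the discrete trading environment
+ ppo_agent = FinRLAgent(FinRLConfig(algorithm="PPO", learning_rate=3e-4, batch_size=64))
+ a2c_agent = FinRLAgent(FinRLConfig(algorithm="A2C", learning_rate=7e-4))
+
+ # DDPG/TD3 configurations also carry replay-buffer settings (continuous action spaces only)
+ td3_config = FinRLConfig(algorithm="TD3", buffer_size=500_000, learning_starts=1_000, tau=0.005)
+ ```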
249
+ ### Trading Environment
250
+ The custom trading environment provides the following (a minimal usage sketch follows the list):
251
+ - **Action Space**: Discrete actions (0=Sell, 1=Hold, 2=Buy)
252
+ - **Observation Space**: OHLCV data + technical indicators + portfolio state
253
+ - **Reward Function**: Portfolio return-based rewards
254
+ - **Transaction Costs**: Proportional trading fee applied to every buy and sell
255
+ - **Position Limits**: Maximum position constraints
256
+
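+ The environment follows the standard Gymnasium `reset()`/`step()` API, so it can also be driven by hand. A minimal sketch, assuming `market_data` is an OHLCV DataFrame like the synthetic data used elsewhere in this README:
+
+ ```python
+ from agentic_ai_system.finrl_agent import TradingEnvironment
+
+ env = TradingEnvironment(market_data, initial_balance=100_000, transaction_fee=0.001)
+ obs, info = env.reset()
+
+ done = False
+ while not done:
+     action = env.action_space.sample()  # random Sell/Hold/Buy, just to exercise the loop
+     obs, reward, done, truncated, info = env.step(action)
+
+ print(info['portfolio_value'], info['total_return'])
+ ```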
257
+ ### Usage Examples
258
+
259
+ #### Basic FinRL Training
260
+ ```python
261
+ from agentic_ai_system.finrl_agent import FinRLAgent, FinRLConfig
262
+ import pandas as pd
263
+
264
+ # Create configuration
265
+ config = FinRLConfig(
266
+ algorithm="PPO",
267
+ learning_rate=0.0003,
268
+ batch_size=64
270
+ )
271
+
272
+ # Initialize agent
273
+ agent = FinRLAgent(config)
274
+
275
+ # Train the agent
276
+ training_result = agent.train(
277
+ data=market_data,
278
+ total_timesteps=100000,
279
+ eval_freq=10000
280
+ )
281
+
282
+ # Generate predictions
283
+ predictions = agent.predict(test_data)
284
+
285
+ # Evaluate performance
286
+ evaluation = agent.evaluate(test_data)
287
+ print(f"Total Return: {evaluation['total_return']:.2%}")
288
+ ```
289
+
290
+ #### Using Configuration File
291
+ ```python
292
+ from agentic_ai_system.finrl_agent import create_finrl_agent_from_config
293
+
294
+ # Create agent from config file
295
+ agent = create_finrl_agent_from_config('config.yaml')
296
+
297
+ # Train and evaluate
298
+ agent.train(market_data)
299
+ results = agent.evaluate(test_data)
300
+ ```
301
+
302
+ #### Running FinRL Demo
303
+ ```bash
304
+ # Run the complete FinRL demo
305
+ python finrl_demo.py
306
+
307
+ # This will:
308
+ # 1. Generate synthetic training and test data
309
+ # 2. Train a FinRL agent
310
+ # 3. Evaluate performance
311
+ # 4. Generate trading predictions
312
+ # 5. Create visualization plots
313
+ ```
314
+
315
+ ### Configuration
316
+ FinRL settings can be configured in `config.yaml`:
317
+
318
+ ```yaml
319
+ finrl:
320
+ algorithm: 'PPO' # PPO, A2C, DDPG, TD3
321
+ learning_rate: 0.0003
322
+ batch_size: 64
323
+ buffer_size: 1000000
324
+ gamma: 0.99
325
+ tensorboard_log: 'logs/finrl_tensorboard'
326
+ training:
327
+ total_timesteps: 100000
328
+ eval_freq: 10000
329
+ save_best_model: true
330
+ model_save_path: 'models/finrl_best/'
331
+ inference:
332
+ use_trained_model: false
333
+ model_path: 'models/finrl_best/best_model'
334
+ ```
335
+
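+ The nested `training` and `inference` blocks are consumed by the driver code (e.g. `finrl_demo.py`) rather than by `FinRLConfig` itself; only the flat keys map onto the dataclass. A minimal sketch of reading them, assuming `agent` and `market_data` from the training example above:
+
+ ```python
+ import yaml
+
+ with open('config.yaml') as fh:
+     cfg = yaml.safe_load(fh)
+
+ train_cfg = cfg['finrl']['training']
+ agent.train(market_data, total_timesteps=train_cfg['total_timesteps'], eval_freq=train_cfg['eval_freq'])
+ ```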
336
+ ### Model Management
337
+ ```python
338
+ # Save trained model
339
+ agent.save_model('models/my_finrl_model')
340
+
341
+ # Load pre-trained model
342
+ agent.load_model('models/my_finrl_model')
343
+
344
+ # Train again on additional data (train() builds a fresh model from the current config)
345
+ agent.train(more_data, total_timesteps=50000)
346
+ ```
347
+
348
+ ### Performance Monitoring
349
+ - **TensorBoard Integration**: Monitor training progress (see the command below)
350
+ - **Evaluation Metrics**: Total return, Sharpe ratio, portfolio value
351
+ - **Trading Statistics**: Buy/sell signal analysis
352
+ - **Visualization**: Price charts with trading signals
353
+
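+ Training logs are written to the `tensorboard_log` directory from the configuration (`logs/finrl_tensorboard` by default), so progress can be inspected with the standard TensorBoard CLI:
+
+ ```bash
+ tensorboard --logdir logs/finrl_tensorboard
+ ```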
354
+ ### Advanced Features
355
+ - **Multi-timeframe Support**: Train on different data frequencies
356
+ - **Feature Engineering**: Automatic technical indicator calculation
357
+ - **Risk Management**: Built-in position and drawdown limits
358
+ - **Backtesting**: Evaluate trained agents on held-out data via `evaluate()`
359
+ - **Hyperparameter Tuning**: Easy configuration for different algorithms (see the sketch below)
360
+
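+ For quick hyperparameter experiments, the same train/evaluate calls can be looped over several configurations. A minimal sketch, assuming `train_data` and `test_data` DataFrames as in the demo (short `total_timesteps` only to keep the comparison fast):
+
+ ```python
+ from agentic_ai_system.finrl_agent import FinRLAgent, FinRLConfig
+
+ results = {}
+ for algo in ("PPO", "A2C"):
+     for lr in (1e-4, 3e-4):
+         agent = FinRLAgent(FinRLConfig(algorithm=algo, learning_rate=lr))
+         agent.train(train_data, total_timesteps=20_000)
+         results[(algo, lr)] = agent.evaluate(test_data)['total_return']
+
+ best = max(results, key=results.get)
+ print(f"Best setup: {best} with return {results[best]:.2%}")
+ ```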
361
  ## Testing
362
 
363
  ### Test Structure
agentic_ai_system/finrl_agent.py ADDED
@@ -0,0 +1,447 @@
1
+ """
2
+ FinRL Agent for Algorithmic Trading
3
+
4
+ This module provides a FinRL-based reinforcement learning agent that can be integrated
5
+ with the existing algorithmic trading system. It supports various RL algorithms
6
+ including PPO, A2C, DDPG, and TD3.
7
+ """
8
+
9
+ import numpy as np
10
+ import pandas as pd
11
+ import gymnasium as gym
12
+ from gymnasium import spaces
13
+ from stable_baselines3 import PPO, A2C, DDPG, TD3
14
+ from stable_baselines3.common.vec_env import DummyVecEnv
15
+ from stable_baselines3.common.callbacks import EvalCallback
16
+ import torch
17
+ import logging
18
+ from typing import Dict, List, Tuple, Optional, Any
19
+ from dataclasses import dataclass
20
+ import yaml
21
+
22
+ logger = logging.getLogger(__name__)
23
+
24
+
25
+ @dataclass
26
+ class FinRLConfig:
27
+ """Configuration for FinRL agent"""
28
+ algorithm: str = "PPO" # PPO, A2C, DDPG, TD3
29
+ learning_rate: float = 0.0003
30
+ batch_size: int = 64
31
+ buffer_size: int = 1000000
32
+ learning_starts: int = 100
33
+ gamma: float = 0.99
34
+ tau: float = 0.005
35
+ train_freq: int = 1
36
+ gradient_steps: int = 1
37
+ target_update_interval: int = 1
38
+ exploration_fraction: float = 0.1
39
+ exploration_initial_eps: float = 1.0
40
+ exploration_final_eps: float = 0.05
41
+ max_grad_norm: float = 10.0
42
+ verbose: int = 1
43
+ tensorboard_log: str = "logs/finrl_tensorboard"
44
+
45
+
46
+ class TradingEnvironment(gym.Env):
47
+ """
48
+ Custom trading environment for FinRL
49
+
50
+ This environment simulates a trading scenario where the agent can:
51
+ - Buy, sell, or hold positions
52
+ - Use technical indicators for decision making
53
+ - Manage portfolio value and risk
54
+ """
55
+
56
+ def __init__(self, data: pd.DataFrame, initial_balance: float = 100000,
57
+ transaction_fee: float = 0.001, max_position: int = 100):
58
+ super().__init__()
59
+
60
+ self.data = data
61
+ self.initial_balance = initial_balance
62
+ self.transaction_fee = transaction_fee
63
+ self.max_position = max_position
64
+
65
+ # Reset state
66
+ self.reset()
67
+
68
+ # Define action space: Discrete(3) -> 0 = Sell, 1 = Hold, 2 = Buy
69
+ self.action_space = spaces.Discrete(3)
70
+
71
+ # Define observation space
72
+ # Features: OHLCV + technical indicators + portfolio state
73
+ n_features = len(self._get_features(self.data.iloc[0]))
74
+ self.observation_space = spaces.Box(
75
+ low=-np.inf, high=np.inf, shape=(n_features,), dtype=np.float32
76
+ )
77
+
78
+ def _get_features(self, row: pd.Series) -> np.ndarray:
79
+ """Extract features from market data row"""
80
+ features = []
81
+
82
+ # Price features
83
+ features.extend([
84
+ row['open'], row['high'], row['low'], row['close'], row['volume']
85
+ ])
86
+
87
+ # Technical indicators (if available)
88
+ for indicator in ['sma_20', 'sma_50', 'rsi', 'bb_upper', 'bb_lower', 'macd']:
89
+ if indicator in row.index:
90
+ features.append(row[indicator])
91
+ else:
92
+ features.append(0.0)
93
+
94
+ # Portfolio state
95
+ features.extend([
96
+ self.balance,
97
+ self.position,
98
+ self.portfolio_value,
99
+ self.total_return
100
+ ])
101
+
102
+ return np.array(features, dtype=np.float32)
103
+
104
+ def _calculate_portfolio_value(self) -> float:
105
+ """Calculate current portfolio value"""
106
+ current_price = self.data.iloc[self.current_step]['close']
107
+ return self.balance + (self.position * current_price)
108
+
109
+ def _calculate_reward(self) -> float:
110
+ """Calculate reward based on portfolio performance"""
111
+ current_value = self._calculate_portfolio_value()
112
+ previous_value = self.previous_portfolio_value
113
+
114
+ # Calculate return
115
+ if previous_value > 0:
116
+ return (current_value - previous_value) / previous_value
117
+ else:
118
+ return 0.0
119
+
120
+ def step(self, action: int) -> Tuple[np.ndarray, float, bool, bool, Dict]:
121
+ """Execute one step in the environment"""
122
+
123
+ # Get current market data
124
+ current_data = self.data.iloc[self.current_step]
125
+ current_price = current_data['close']
126
+
127
+ # Execute action
128
+ if action == 0: # Sell
129
+ if self.position > 0:
130
+ shares_to_sell = min(self.position, self.max_position)
131
+ sell_value = shares_to_sell * current_price * (1 - self.transaction_fee)
132
+ self.balance += sell_value
133
+ self.position -= shares_to_sell
134
+ elif action == 2: # Buy
135
+ if self.balance > 0:
136
+ max_shares = min(
137
+ # Divide by the fee-adjusted price so the purchase cannot overdraw the balance
+ int(self.balance / (current_price * (1 + self.transaction_fee))),
138
+ self.max_position - self.position
139
+ )
140
+ if max_shares > 0:
141
+ buy_value = max_shares * current_price * (1 + self.transaction_fee)
142
+ self.balance -= buy_value
143
+ self.position += max_shares
144
+
145
+ # Update portfolio value
146
+ self.previous_portfolio_value = self.portfolio_value
147
+ self.portfolio_value = self._calculate_portfolio_value()
148
+ self.total_return = (self.portfolio_value - self.initial_balance) / self.initial_balance
149
+
150
+ # Calculate reward
151
+ reward = self._calculate_reward()
152
+
153
+ # Move to next step
154
+ self.current_step += 1
155
+
156
+ # Check if episode is done
157
+ done = self.current_step >= len(self.data) - 1
158
+
159
+ # Get observation
160
+ if not done:
161
+ observation = self._get_features(self.data.iloc[self.current_step])
162
+ else:
163
+ # Use last available data for final observation
164
+ observation = self._get_features(self.data.iloc[-1])
165
+
166
+ info = {
167
+ 'balance': self.balance,
168
+ 'position': self.position,
169
+ 'portfolio_value': self.portfolio_value,
170
+ 'total_return': self.total_return,
171
+ 'current_price': current_price
172
+ }
173
+
174
+ return observation, reward, done, False, info
175
+
176
+ def reset(self, seed: Optional[int] = None) -> Tuple[np.ndarray, Dict]:
177
+ """Reset the environment"""
178
+ super().reset(seed=seed)
179
+
180
+ self.current_step = 0
181
+ self.balance = self.initial_balance
182
+ self.position = 0
183
+ self.portfolio_value = self.initial_balance
184
+ self.previous_portfolio_value = self.initial_balance
185
+ self.total_return = 0.0
186
+
187
+ observation = self._get_features(self.data.iloc[self.current_step])
188
+ info = {
189
+ 'balance': self.balance,
190
+ 'position': self.position,
191
+ 'portfolio_value': self.portfolio_value,
192
+ 'total_return': self.total_return
193
+ }
194
+
195
+ return observation, info
196
+
197
+
198
+ class FinRLAgent:
199
+ """
200
+ FinRL-based reinforcement learning agent for algorithmic trading
201
+ """
202
+
203
+ def __init__(self, config: FinRLConfig):
204
+ self.config = config
205
+ self.model = None
206
+ self.env = None
207
+ self.eval_env = None
208
+ self.callback = None
209
+
210
+ logger.info(f"Initializing FinRL agent with algorithm: {config.algorithm}")
211
+
212
+ def create_environment(self, data: pd.DataFrame, initial_balance: float = 100000) -> TradingEnvironment:
213
+ """Create trading environment from market data"""
214
+ return TradingEnvironment(
215
+ data=data,
216
+ initial_balance=initial_balance,
217
+ transaction_fee=0.001,
218
+ max_position=100
219
+ )
220
+
221
+ def prepare_data(self, data: pd.DataFrame) -> pd.DataFrame:
222
+ """Prepare data with technical indicators for FinRL"""
223
+ df = data.copy()
224
+
225
+ # Add technical indicators if not present
226
+ if 'sma_20' not in df.columns:
227
+ df['sma_20'] = df['close'].rolling(window=20).mean()
228
+ if 'sma_50' not in df.columns:
229
+ df['sma_50'] = df['close'].rolling(window=50).mean()
230
+ if 'rsi' not in df.columns:
231
+ df['rsi'] = self._calculate_rsi(df['close'])
232
+ if 'bb_upper' not in df.columns or 'bb_lower' not in df.columns:
233
+ bb_upper, bb_lower = self._calculate_bollinger_bands(df['close'])
234
+ df['bb_upper'] = bb_upper
235
+ df['bb_lower'] = bb_lower
236
+ if 'macd' not in df.columns:
237
+ df['macd'] = self._calculate_macd(df['close'])
238
+
239
+ # Fill NaN values
240
+ df = df.bfill().fillna(0)  # backfill indicator warm-up NaNs (fillna(method=...) is deprecated)
241
+
242
+ return df
243
+
244
+ def _calculate_rsi(self, prices: pd.Series, period: int = 14) -> pd.Series:
245
+ """Calculate RSI indicator"""
246
+ delta = prices.diff()
247
+ gain = (delta.where(delta > 0, 0)).rolling(window=period).mean()
248
+ loss = (-delta.where(delta < 0, 0)).rolling(window=period).mean()
249
+ rs = gain / loss
250
+ rsi = 100 - (100 / (1 + rs))
251
+ return rsi
252
+
253
+ def _calculate_bollinger_bands(self, prices: pd.Series, period: int = 20, std_dev: int = 2) -> Tuple[pd.Series, pd.Series]:
254
+ """Calculate Bollinger Bands"""
255
+ sma = prices.rolling(window=period).mean()
256
+ std = prices.rolling(window=period).std()
257
+ upper_band = sma + (std * std_dev)
258
+ lower_band = sma - (std * std_dev)
259
+ return upper_band, lower_band
260
+
261
+ def _calculate_macd(self, prices: pd.Series, fast: int = 12, slow: int = 26, signal: int = 9) -> pd.Series:
262
+ """Calculate MACD indicator"""
263
+ ema_fast = prices.ewm(span=fast).mean()
264
+ ema_slow = prices.ewm(span=slow).mean()
265
+ macd_line = ema_fast - ema_slow
266
+ return macd_line
267
+
268
+ def train(self, data: pd.DataFrame, total_timesteps: int = 100000,
269
+ eval_freq: int = 10000, eval_data: Optional[pd.DataFrame] = None) -> Dict[str, Any]:
270
+ """Train the FinRL agent"""
271
+
272
+ logger.info("Starting FinRL agent training")
273
+
274
+ # Prepare data
275
+ train_data = self.prepare_data(data)
276
+
277
+ # Create training environment
278
+ self.env = DummyVecEnv([lambda: self.create_environment(train_data)])
279
+
280
+ # Create evaluation environment if provided
281
+ if eval_data is not None:
282
+ eval_data = self.prepare_data(eval_data)
283
+ self.eval_env = DummyVecEnv([lambda: self.create_environment(eval_data)])
284
+ self.callback = EvalCallback(
285
+ self.eval_env,
286
+ best_model_save_path="models/finrl_best/",
287
+ log_path="logs/finrl_eval/",
288
+ eval_freq=eval_freq,
289
+ deterministic=True,
290
+ render=False
291
+ )
292
+
293
+ # Initialize model based on algorithm
294
+ if self.config.algorithm == "PPO":
295
+ self.model = PPO(
296
+ "MlpPolicy",
297
+ self.env,
298
+ learning_rate=self.config.learning_rate,
299
+ batch_size=self.config.batch_size,
300
+ gamma=self.config.gamma,
301
+ verbose=self.config.verbose,
302
+ tensorboard_log=self.config.tensorboard_log
303
+ )
304
+ elif self.config.algorithm == "A2C":
305
+ self.model = A2C(
306
+ "MlpPolicy",
307
+ self.env,
308
+ learning_rate=self.config.learning_rate,
309
+ gamma=self.config.gamma,
310
+ verbose=self.config.verbose,
311
+ tensorboard_log=self.config.tensorboard_log
312
+ )
313
+ elif self.config.algorithm == "DDPG":
314
+ self.model = DDPG(
315
+ "MlpPolicy",
316
+ self.env,
317
+ learning_rate=self.config.learning_rate,
318
+ buffer_size=self.config.buffer_size,
319
+ learning_starts=self.config.learning_starts,
320
+ gamma=self.config.gamma,
321
+ tau=self.config.tau,
322
+ train_freq=self.config.train_freq,
323
+ gradient_steps=self.config.gradient_steps,
324
+ verbose=self.config.verbose,
325
+ tensorboard_log=self.config.tensorboard_log
326
+ )
327
+ elif self.config.algorithm == "TD3":
328
+ self.model = TD3(
329
+ "MlpPolicy",
330
+ self.env,
331
+ learning_rate=self.config.learning_rate,
332
+ buffer_size=self.config.buffer_size,
333
+ learning_starts=self.config.learning_starts,
334
+ gamma=self.config.gamma,
335
+ tau=self.config.tau,
336
+ train_freq=self.config.train_freq,
337
+ gradient_steps=self.config.gradient_steps,
338
+ target_update_interval=self.config.target_update_interval,
339
+ verbose=self.config.verbose,
340
+ tensorboard_log=self.config.tensorboard_log
341
+ )
342
+ else:
343
+ raise ValueError(f"Unsupported algorithm: {self.config.algorithm}")
344
+
345
+ # Train the model
346
+ callbacks = [self.callback] if self.callback else None
347
+ self.model.learn(
348
+ total_timesteps=total_timesteps,
349
+ callback=callbacks
350
+ )
351
+
352
+ logger.info("FinRL agent training completed")
353
+
354
+ return {
355
+ 'algorithm': self.config.algorithm,
356
+ 'total_timesteps': total_timesteps,
357
+ 'model_path': f"models/finrl_{self.config.algorithm.lower()}"
358
+ }
359
+
360
+ def predict(self, data: pd.DataFrame) -> List[int]:
361
+ """Generate trading predictions using the trained model"""
362
+ if self.model is None:
363
+ raise ValueError("Model not trained. Call train() first.")
364
+
365
+ # Prepare data
366
+ test_data = self.prepare_data(data)
367
+
368
+ # Create test environment
369
+ test_env = self.create_environment(test_data)
370
+
371
+ predictions = []
372
+ obs, _ = test_env.reset()
373
+
374
+ done = False
375
+ while not done:
376
+ action, _ = self.model.predict(obs, deterministic=True)
377
+ predictions.append(action)
378
+ obs, _, done, _, _ = test_env.step(action)
379
+
380
+ return predictions
381
+
382
+ def evaluate(self, data: pd.DataFrame) -> Dict[str, float]:
383
+ """Evaluate the trained model on test data"""
384
+ if self.model is None:
385
+ raise ValueError("Model not trained. Call train() first.")
386
+
387
+ # Prepare data
388
+ test_data = self.prepare_data(data)
389
+
390
+ # Create test environment
391
+ test_env = self.create_environment(test_data)
392
+
393
+ obs, _ = test_env.reset()
394
+ done = False
395
+ total_reward = 0
396
+ steps = 0
397
+
398
+ while not done:
399
+ action, _ = self.model.predict(obs, deterministic=True)
400
+ obs, reward, done, _, info = test_env.step(action)
401
+ total_reward += reward
402
+ steps += 1
403
+
404
+ # Calculate metrics
405
+ final_portfolio_value = info['portfolio_value']
406
+ initial_balance = test_env.initial_balance
407
+ total_return = (final_portfolio_value - initial_balance) / initial_balance
408
+
409
+ return {
410
+ 'total_reward': total_reward,
411
+ 'total_return': total_return,
412
+ 'final_portfolio_value': final_portfolio_value,
413
+ 'steps': steps,
414
+ # Simplified proxy: mean per-step reward, not a risk-adjusted Sharpe ratio
+ 'sharpe_ratio': total_reward / steps if steps > 0 else 0
415
+ }
416
+
417
+ def save_model(self, path: str):
418
+ """Save the trained model"""
419
+ if self.model is None:
420
+ raise ValueError("No model to save. Train the model first.")
421
+
422
+ self.model.save(path)
423
+ logger.info(f"Model saved to {path}")
424
+
425
+ def load_model(self, path: str):
426
+ """Load a trained model"""
427
+ if self.config.algorithm == "PPO":
428
+ self.model = PPO.load(path)
429
+ elif self.config.algorithm == "A2C":
430
+ self.model = A2C.load(path)
431
+ elif self.config.algorithm == "DDPG":
432
+ self.model = DDPG.load(path)
433
+ elif self.config.algorithm == "TD3":
434
+ self.model = TD3.load(path)
435
+ else:
436
+ raise ValueError(f"Unsupported algorithm: {self.config.algorithm}")
437
+
438
+ logger.info(f"Model loaded from {path}")
439
+
440
+
441
+ def create_finrl_agent_from_config(config_path: str) -> FinRLAgent:
442
+ """Create FinRL agent from configuration file"""
443
+ with open(config_path, 'r') as file:
444
+ config_data = yaml.safe_load(file)
445
+
446
+ # Keep only flat settings; nested blocks such as 'training'/'inference' are not FinRLConfig fields
+ finrl_settings = {k: v for k, v in config_data.get('finrl', {}).items() if not isinstance(v, dict)}
+ finrl_config = FinRLConfig(**finrl_settings)
447
+ return FinRLAgent(finrl_config)
config.yaml CHANGED
@@ -33,3 +33,30 @@ logging:
33
  enable_file: true
34
  max_file_size_mb: 10
35
  backup_count: 5
33
  enable_file: true
34
  max_file_size_mb: 10
35
  backup_count: 5
36
+
37
+ # FinRL configuration
38
+ finrl:
39
+ algorithm: 'PPO' # PPO, A2C, DDPG, TD3
40
+ learning_rate: 0.0003
41
+ batch_size: 64
42
+ buffer_size: 1000000
43
+ learning_starts: 100
44
+ gamma: 0.99
45
+ tau: 0.005
46
+ train_freq: 1
47
+ gradient_steps: 1
48
+ target_update_interval: 1
49
+ exploration_fraction: 0.1
50
+ exploration_initial_eps: 1.0
51
+ exploration_final_eps: 0.05
52
+ max_grad_norm: 10.0
53
+ verbose: 1
54
+ tensorboard_log: 'logs/finrl_tensorboard'
55
+ training:
56
+ total_timesteps: 100000
57
+ eval_freq: 10000
58
+ save_best_model: true
59
+ model_save_path: 'models/finrl_best/'
60
+ inference:
61
+ use_trained_model: false
62
+ model_path: 'models/finrl_best/best_model'
finrl_demo.py ADDED
@@ -0,0 +1,294 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ FinRL Demo Script
4
+
5
+ This script demonstrates the integration of FinRL with the algorithmic trading system.
6
+ It shows how to train a reinforcement learning agent and use it for trading decisions.
7
+ """
8
+
9
+ import os
10
+ import sys
11
+ import yaml
12
+ import pandas as pd
13
+ import numpy as np
14
+ import matplotlib.pyplot as plt
15
+ import seaborn as sns
16
+ from datetime import datetime, timedelta
17
+ import logging
18
+
19
+ # Add the project root to the path
20
+ sys.path.append(os.path.dirname(os.path.abspath(__file__)))
21
+
22
+ from agentic_ai_system.finrl_agent import FinRLAgent, FinRLConfig, create_finrl_agent_from_config
23
+ from agentic_ai_system.synthetic_data_generator import SyntheticDataGenerator
24
+ from agentic_ai_system.logger_config import setup_logging
25
+
26
+ # Setup logging
27
+ setup_logging()
28
+ logger = logging.getLogger(__name__)
29
+
30
+
31
+ def load_config(config_path: str = 'config.yaml') -> dict:
32
+ """Load configuration from YAML file"""
33
+ with open(config_path, 'r') as file:
34
+ return yaml.safe_load(file)
35
+
36
+
37
+ def generate_training_data(config: dict) -> pd.DataFrame:
38
+ """Generate synthetic data for training"""
39
+ logger.info("Generating synthetic training data")
40
+
41
+ generator = SyntheticDataGenerator(config)
42
+
43
+ # Generate training data (longer period)
44
+ train_data = generator.generate_ohlcv_data(
45
+ symbol='AAPL',
46
+ start_date='2023-01-01',
47
+ end_date='2023-12-31',
48
+ frequency='1H'
49
+ )
50
+
51
+ # Add technical indicators
52
+ train_data['sma_20'] = train_data['close'].rolling(window=20).mean()
53
+ train_data['sma_50'] = train_data['close'].rolling(window=50).mean()
54
+ train_data['rsi'] = calculate_rsi(train_data['close'])
55
+ bb_upper, bb_lower = calculate_bollinger_bands(train_data['close'])
56
+ train_data['bb_upper'] = bb_upper
57
+ train_data['bb_lower'] = bb_lower
58
+ train_data['macd'] = calculate_macd(train_data['close'])
59
+
60
+ # Fill NaN values
61
+ train_data = train_data.fillna(method='bfill').fillna(0)
62
+
63
+ logger.info(f"Generated {len(train_data)} training samples")
64
+ return train_data
65
+
66
+
67
+ def generate_test_data(config: dict) -> pd.DataFrame:
68
+ """Generate synthetic data for testing"""
69
+ logger.info("Generating synthetic test data")
70
+
71
+ generator = SyntheticDataGenerator(config)
72
+
73
+ # Generate test data (shorter period)
74
+ test_data = generator.generate_ohlcv_data(
75
+ symbol='AAPL',
76
+ start_date='2024-01-01',
77
+ end_date='2024-03-31',
78
+ frequency='1H'
79
+ )
80
+
81
+ # Add technical indicators
82
+ test_data['sma_20'] = test_data['close'].rolling(window=20).mean()
83
+ test_data['sma_50'] = test_data['close'].rolling(window=50).mean()
84
+ test_data['rsi'] = calculate_rsi(test_data['close'])
85
+ bb_upper, bb_lower = calculate_bollinger_bands(test_data['close'])
86
+ test_data['bb_upper'] = bb_upper
87
+ test_data['bb_lower'] = bb_lower
88
+ test_data['macd'] = calculate_macd(test_data['close'])
89
+
90
+ # Fill NaN values
91
+ test_data = test_data.fillna(method='bfill').fillna(0)
92
+
93
+ logger.info(f"Generated {len(test_data)} test samples")
94
+ return test_data
95
+
96
+
97
+ def calculate_rsi(prices: pd.Series, period: int = 14) -> pd.Series:
98
+ """Calculate RSI indicator"""
99
+ delta = prices.diff()
100
+ gain = (delta.where(delta > 0, 0)).rolling(window=period).mean()
101
+ loss = (-delta.where(delta < 0, 0)).rolling(window=period).mean()
102
+ rs = gain / loss
103
+ rsi = 100 - (100 / (1 + rs))
104
+ return rsi
105
+
106
+
107
+ def calculate_bollinger_bands(prices: pd.Series, period: int = 20, std_dev: int = 2):
108
+ """Calculate Bollinger Bands"""
109
+ sma = prices.rolling(window=period).mean()
110
+ std = prices.rolling(window=period).std()
111
+ upper_band = sma + (std * std_dev)
112
+ lower_band = sma - (std * std_dev)
113
+ return upper_band, lower_band
114
+
115
+
116
+ def calculate_macd(prices: pd.Series, fast: int = 12, slow: int = 26, signal: int = 9) -> pd.Series:
117
+ """Calculate MACD indicator"""
118
+ ema_fast = prices.ewm(span=fast).mean()
119
+ ema_slow = prices.ewm(span=slow).mean()
120
+ macd_line = ema_fast - ema_slow
121
+ return macd_line
122
+
123
+
124
+ def train_finrl_agent(config: dict, train_data: pd.DataFrame, test_data: pd.DataFrame) -> FinRLAgent:
125
+ """Train the FinRL agent"""
126
+ logger.info("Starting FinRL agent training")
127
+
128
+ # Create FinRL agent
129
+ # Keep only flat FinRL settings; the nested 'training'/'inference' blocks are not FinRLConfig fields
+ finrl_settings = {k: v for k, v in config['finrl'].items() if not isinstance(v, dict)}
+ finrl_config = FinRLConfig(**finrl_settings)
130
+ agent = FinRLAgent(finrl_config)
131
+
132
+ # Train the agent
133
+ training_result = agent.train(
134
+ data=train_data,
135
+ total_timesteps=config['finrl']['training']['total_timesteps'],
136
+ eval_freq=config['finrl']['training']['eval_freq'],
137
+ eval_data=test_data
138
+ )
139
+
140
+ logger.info(f"Training completed: {training_result}")
141
+
142
+ # Save the model
143
+ if config['finrl']['training']['save_best_model']:
144
+ model_path = config['finrl']['training']['model_save_path']
145
+ os.makedirs(os.path.dirname(model_path), exist_ok=True)
146
+ agent.save_model(model_path)
147
+
148
+ return agent
149
+
150
+
151
+ def evaluate_agent(agent: FinRLAgent, test_data: pd.DataFrame) -> dict:
152
+ """Evaluate the trained agent"""
153
+ logger.info("Evaluating FinRL agent")
154
+
155
+ # Evaluate on test data
156
+ evaluation_results = agent.evaluate(test_data)
157
+
158
+ logger.info(f"Evaluation results: {evaluation_results}")
159
+
160
+ return evaluation_results
161
+
162
+
163
+ def generate_predictions(agent: FinRLAgent, test_data: pd.DataFrame) -> list:
164
+ """Generate trading predictions"""
165
+ logger.info("Generating trading predictions")
166
+
167
+ predictions = agent.predict(test_data)
168
+
169
+ logger.info(f"Generated {len(predictions)} predictions")
170
+
171
+ return predictions
172
+
173
+
174
+ def plot_results(test_data: pd.DataFrame, predictions: list, evaluation_results: dict):
175
+ """Plot trading results"""
176
+ logger.info("Creating visualization plots")
177
+
178
+ # Create figure with subplots
179
+ fig, axes = plt.subplots(3, 1, figsize=(15, 12))
180
+
181
+ # Plot 1: Price and predictions
182
+ axes[0].plot(test_data.index, test_data['close'], label='Close Price', alpha=0.7)
183
+
184
+ # Mark buy/sell signals
185
+ buy_signals = [i for i, pred in enumerate(predictions) if pred == 2]
186
+ sell_signals = [i for i, pred in enumerate(predictions) if pred == 0]
187
+
188
+ if buy_signals:
189
+ axes[0].scatter(test_data.index[buy_signals], test_data['close'].iloc[buy_signals],
190
+ color='green', marker='^', s=100, label='Buy Signal', alpha=0.8)
191
+ if sell_signals:
192
+ axes[0].scatter(test_data.index[sell_signals], test_data['close'].iloc[sell_signals],
193
+ color='red', marker='v', s=100, label='Sell Signal', alpha=0.8)
194
+
195
+ axes[0].set_title('Price Action and Trading Signals')
196
+ axes[0].set_ylabel('Price')
197
+ axes[0].legend()
198
+ axes[0].grid(True, alpha=0.3)
199
+
200
+ # Plot 2: Technical indicators
201
+ axes[1].plot(test_data.index, test_data['close'], label='Close Price', alpha=0.7)
202
+ axes[1].plot(test_data.index, test_data['sma_20'], label='SMA 20', alpha=0.7)
203
+ axes[1].plot(test_data.index, test_data['sma_50'], label='SMA 50', alpha=0.7)
204
+ axes[1].plot(test_data.index, test_data['bb_upper'], label='BB Upper', alpha=0.5)
205
+ axes[1].plot(test_data.index, test_data['bb_lower'], label='BB Lower', alpha=0.5)
206
+
207
+ axes[1].set_title('Technical Indicators')
208
+ axes[1].set_ylabel('Price')
209
+ axes[1].legend()
210
+ axes[1].grid(True, alpha=0.3)
211
+
212
+ # Plot 3: RSI
213
+ axes[2].plot(test_data.index, test_data['rsi'], label='RSI', color='purple')
214
+ axes[2].axhline(y=70, color='r', linestyle='--', alpha=0.5, label='Overbought')
215
+ axes[2].axhline(y=30, color='g', linestyle='--', alpha=0.5, label='Oversold')
216
+ axes[2].set_title('RSI Indicator')
217
+ axes[2].set_ylabel('RSI')
218
+ axes[2].set_xlabel('Time')
219
+ axes[2].legend()
220
+ axes[2].grid(True, alpha=0.3)
221
+
222
+ plt.tight_layout()
223
+
224
+ # Save plot
225
+ os.makedirs('plots', exist_ok=True)
226
+ plt.savefig('plots/finrl_trading_results.png', dpi=300, bbox_inches='tight')
227
+ plt.show()
228
+
229
+ logger.info("Plots saved to plots/finrl_trading_results.png")
230
+
231
+
232
+ def print_summary(evaluation_results: dict, predictions: list):
233
+ """Print trading summary"""
234
+ print("\n" + "="*60)
235
+ print("FINRL TRADING SYSTEM SUMMARY")
236
+ print("="*60)
237
+
238
+ print(f"Algorithm: {evaluation_results.get('algorithm', 'Unknown')}")
239
+ print(f"Total Return: {evaluation_results['total_return']:.2%}")
240
+ print(f"Final Portfolio Value: ${evaluation_results['final_portfolio_value']:,.2f}")
241
+ print(f"Total Reward: {evaluation_results['total_reward']:.4f}")
242
+ print(f"Sharpe Ratio: {evaluation_results['sharpe_ratio']:.4f}")
243
+ print(f"Number of Trading Steps: {evaluation_results['steps']}")
244
+
245
+ # Trading statistics
246
+ buy_signals = sum(1 for pred in predictions if pred == 2)
247
+ sell_signals = sum(1 for pred in predictions if pred == 0)
248
+ hold_signals = sum(1 for pred in predictions if pred == 1)
249
+
250
+ print(f"\nTrading Signals:")
251
+ print(f" Buy signals: {buy_signals}")
252
+ print(f" Sell signals: {sell_signals}")
253
+ print(f" Hold signals: {hold_signals}")
254
+ print(f" Total signals: {len(predictions)}")
255
+
256
+ print("\n" + "="*60)
257
+
258
+
259
+ def main():
260
+ """Main function to run the FinRL demo"""
261
+ logger.info("Starting FinRL Demo")
262
+
263
+ try:
264
+ # Load configuration
265
+ config = load_config()
266
+
267
+ # Generate data
268
+ train_data = generate_training_data(config)
269
+ test_data = generate_test_data(config)
270
+
271
+ # Train FinRL agent
272
+ agent = train_finrl_agent(config, train_data, test_data)
273
+
274
+ # Evaluate agent
275
+ evaluation_results = evaluate_agent(agent, test_data)
276
+
277
+ # Generate predictions
278
+ predictions = generate_predictions(agent, test_data)
279
+
280
+ # Create visualizations
281
+ plot_results(test_data, predictions, evaluation_results)
282
+
283
+ # Print summary
284
+ print_summary(evaluation_results, predictions)
285
+
286
+ logger.info("FinRL Demo completed successfully")
287
+
288
+ except Exception as e:
289
+ logger.error(f"Error in FinRL demo: {str(e)}")
290
+ raise
291
+
292
+
293
+ if __name__ == "__main__":
294
+ main()
requirements.txt CHANGED
@@ -7,3 +7,8 @@ pytest
7
  pytest-cov
8
  python-dateutil
9
  scipy
7
  pytest-cov
8
  python-dateutil
9
  scipy
10
+ finrl
11
+ stable-baselines3
12
+ gymnasium
13
+ tensorboard
14
+ torch
tests/test_finrl_agent.py ADDED
@@ -0,0 +1,373 @@
1
+ """
2
+ Tests for FinRL Agent
3
+
4
+ This module contains comprehensive tests for the FinRL agent functionality.
5
+ """
6
+
7
+ import pytest
8
+ import pandas as pd
9
+ import numpy as np
10
+ import yaml
11
+ import tempfile
12
+ import os
13
+ from unittest.mock import Mock, patch
14
+
15
+ # Add the project root to the path
16
+ import sys
17
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
18
+
19
+ from agentic_ai_system.finrl_agent import (
20
+ FinRLAgent,
21
+ FinRLConfig,
22
+ TradingEnvironment,
23
+ create_finrl_agent_from_config
24
+ )
25
+
26
+
27
+ class TestFinRLConfig:
28
+ """Test FinRL configuration"""
29
+
30
+ def test_default_config(self):
31
+ """Test default configuration values"""
32
+ config = FinRLConfig()
33
+
34
+ assert config.algorithm == "PPO"
35
+ assert config.learning_rate == 0.0003
36
+ assert config.batch_size == 64
37
+ assert config.gamma == 0.99
38
+
39
+ def test_custom_config(self):
40
+ """Test custom configuration values"""
41
+ config = FinRLConfig(
42
+ algorithm="A2C",
43
+ learning_rate=0.001,
44
+ batch_size=128
45
+ )
46
+
47
+ assert config.algorithm == "A2C"
48
+ assert config.learning_rate == 0.001
49
+ assert config.batch_size == 128
50
+
51
+
52
+ class TestTradingEnvironment:
53
+ """Test trading environment"""
54
+
55
+ @pytest.fixture
56
+ def sample_data(self):
57
+ """Create sample market data"""
58
+ dates = pd.date_range('2024-01-01', periods=100, freq='1H')
59
+ data = pd.DataFrame({
60
+ 'open': np.random.uniform(100, 200, 100),
61
+ 'high': np.random.uniform(100, 200, 100),
62
+ 'low': np.random.uniform(100, 200, 100),
63
+ 'close': np.random.uniform(100, 200, 100),
64
+ 'volume': np.random.uniform(1000, 10000, 100),
65
+ 'sma_20': np.random.uniform(100, 200, 100),
66
+ 'sma_50': np.random.uniform(100, 200, 100),
67
+ 'rsi': np.random.uniform(0, 100, 100),
68
+ 'bb_upper': np.random.uniform(100, 200, 100),
69
+ 'bb_lower': np.random.uniform(100, 200, 100),
70
+ 'macd': np.random.uniform(-10, 10, 100)
71
+ }, index=dates)
72
+ return data
73
+
74
+ def test_environment_initialization(self, sample_data):
75
+ """Test environment initialization"""
76
+ env = TradingEnvironment(sample_data)
77
+
78
+ assert env.initial_balance == 100000
79
+ assert env.transaction_fee == 0.001
80
+ assert env.max_position == 100
81
+ assert env.action_space.n == 3
82
+ assert len(env.observation_space.shape) == 1
83
+
84
+ def test_environment_reset(self, sample_data):
85
+ """Test environment reset"""
86
+ env = TradingEnvironment(sample_data)
87
+ obs, info = env.reset()
88
+
89
+ assert env.current_step == 0
90
+ assert env.balance == env.initial_balance
91
+ assert env.position == 0
92
+ assert env.portfolio_value == env.initial_balance
93
+ assert isinstance(obs, np.ndarray)
94
+ assert isinstance(info, dict)
95
+
96
+ def test_environment_step(self, sample_data):
97
+ """Test environment step"""
98
+ env = TradingEnvironment(sample_data)
99
+ obs, info = env.reset()
100
+
101
+ # Test hold action
102
+ obs, reward, done, truncated, info = env.step(1)
103
+
104
+ assert isinstance(obs, np.ndarray)
105
+ assert isinstance(reward, float)
106
+ assert isinstance(done, bool)
107
+ assert isinstance(truncated, bool)
108
+ assert isinstance(info, dict)
109
+ assert env.current_step == 1
110
+
111
+ def test_buy_action(self, sample_data):
112
+ """Test buy action"""
113
+ env = TradingEnvironment(sample_data, initial_balance=10000)
114
+ obs, info = env.reset()
115
+
116
+ initial_balance = env.balance
117
+ initial_position = env.position
118
+
119
+ # Buy action
120
+ obs, reward, done, truncated, info = env.step(2)
121
+
122
+ assert env.position > initial_position
123
+ assert env.balance < initial_balance
124
+
125
+ def test_sell_action(self, sample_data):
126
+ """Test sell action"""
127
+ env = TradingEnvironment(sample_data, initial_balance=10000)
128
+ obs, info = env.reset()
129
+
130
+ # First buy some shares
131
+ obs, reward, done, truncated, info = env.step(2)
132
+ initial_position = env.position
133
+ initial_balance = env.balance
134
+
135
+ # Then sell
136
+ obs, reward, done, truncated, info = env.step(0)
137
+
138
+ assert env.position < initial_position
139
+ assert env.balance > initial_balance
140
+
141
+ def test_portfolio_value_calculation(self, sample_data):
142
+ """Test portfolio value calculation"""
143
+ env = TradingEnvironment(sample_data)
144
+ obs, info = env.reset()
145
+
146
+ # Buy some shares
147
+ obs, reward, done, truncated, info = env.step(2)
148
+
149
+ expected_value = env.balance + (env.position * sample_data.iloc[env.current_step]['close'])
150
+ assert abs(env.portfolio_value - expected_value) < 1e-6
151
+
152
+
153
+ class TestFinRLAgent:
154
+ """Test FinRL agent"""
155
+
156
+ @pytest.fixture
157
+ def sample_data(self):
158
+ """Create sample market data"""
159
+ dates = pd.date_range('2024-01-01', periods=100, freq='1H')
160
+ data = pd.DataFrame({
161
+ 'open': np.random.uniform(100, 200, 100),
162
+ 'high': np.random.uniform(100, 200, 100),
163
+ 'low': np.random.uniform(100, 200, 100),
164
+ 'close': np.random.uniform(100, 200, 100),
165
+ 'volume': np.random.uniform(1000, 10000, 100)
166
+ }, index=dates)
167
+ return data
168
+
169
+ @pytest.fixture
170
+ def finrl_config(self):
171
+ """Create FinRL configuration"""
172
+ return FinRLConfig(
173
+ algorithm="PPO",
174
+ learning_rate=0.0003,
175
+ batch_size=32,
176
+ total_timesteps=1000
177
+ )
178
+
179
+ def test_agent_initialization(self, finrl_config):
180
+ """Test agent initialization"""
181
+ agent = FinRLAgent(finrl_config)
182
+
183
+ assert agent.config == finrl_config
184
+ assert agent.model is None
185
+ assert agent.env is None
186
+
187
+ def test_prepare_data(self, finrl_config, sample_data):
188
+ """Test data preparation"""
189
+ agent = FinRLAgent(finrl_config)
190
+ prepared_data = agent.prepare_data(sample_data)
191
+
192
+ # Check that technical indicators were added
193
+ assert 'sma_20' in prepared_data.columns
194
+ assert 'sma_50' in prepared_data.columns
195
+ assert 'rsi' in prepared_data.columns
196
+ assert 'bb_upper' in prepared_data.columns
197
+ assert 'bb_lower' in prepared_data.columns
198
+ assert 'macd' in prepared_data.columns
199
+
200
+ # Check that no NaN values remain
201
+ assert not prepared_data.isnull().any().any()
202
+
203
+ def test_create_environment(self, finrl_config, sample_data):
204
+ """Test environment creation"""
205
+ agent = FinRLAgent(finrl_config)
206
+ env = agent.create_environment(sample_data)
207
+
208
+ assert isinstance(env, TradingEnvironment)
209
+ assert env.data.equals(sample_data)
210
+
211
+ def test_technical_indicators_calculation(self, finrl_config):
212
+ """Test technical indicators calculation"""
213
+ agent = FinRLAgent(finrl_config)
214
+
215
+ # Test RSI calculation
216
+ prices = pd.Series([100, 101, 99, 102, 98, 103, 97, 104, 96, 105])
217
+ rsi = agent._calculate_rsi(prices, period=3)
218
+ assert len(rsi) == len(prices)
219
+ assert not rsi.isnull().all()
220
+
221
+ # Test Bollinger Bands calculation
222
+ bb_upper, bb_lower = agent._calculate_bollinger_bands(prices, period=3)
223
+ assert len(bb_upper) == len(prices)
224
+ assert len(bb_lower) == len(prices)
225
+ assert (bb_upper >= bb_lower).all()
226
+
227
+ # Test MACD calculation
228
+ macd = agent._calculate_macd(prices)
229
+ assert len(macd) == len(prices)
230
+
231
+ @patch('agentic_ai_system.finrl_agent.PPO')
232
+ def test_training_ppo(self, mock_ppo, finrl_config, sample_data):
233
+ """Test PPO training"""
234
+ # Mock the PPO model
235
+ mock_model = Mock()
236
+ mock_ppo.return_value = mock_model
237
+
238
+ agent = FinRLAgent(finrl_config)
239
+ result = agent.train(sample_data, total_timesteps=100)
240
+
241
+ assert result['algorithm'] == 'PPO'
242
+ assert result['total_timesteps'] == 100
243
+ mock_model.learn.assert_called_once()
244
+
245
+ @patch('agentic_ai_system.finrl_agent.A2C')
246
+ def test_training_a2c(self, mock_a2c):
247
+ """Test A2C training"""
248
+ config = FinRLConfig(algorithm="A2C")
249
+ mock_model = Mock()
250
+ mock_a2c.return_value = mock_model
251
+
252
+ agent = FinRLAgent(config)
253
+ sample_data = pd.DataFrame({
254
+ 'open': [100, 101, 102],
255
+ 'high': [101, 102, 103],
256
+ 'low': [99, 100, 101],
257
+ 'close': [100, 101, 102],
258
+ 'volume': [1000, 1100, 1200]
259
+ })
260
+
261
+ result = agent.train(sample_data, total_timesteps=100)
262
+
263
+ assert result['algorithm'] == 'A2C'
264
+ mock_model.learn.assert_called_once()
265
+
266
+ def test_invalid_algorithm(self):
267
+ """Test invalid algorithm handling"""
268
+ config = FinRLConfig(algorithm="INVALID")
269
+ agent = FinRLAgent(config)
270
+ sample_data = pd.DataFrame({
271
+ 'open': [100, 101, 102],
272
+ 'high': [101, 102, 103],
273
+ 'low': [99, 100, 101],
274
+ 'close': [100, 101, 102],
275
+ 'volume': [1000, 1100, 1200]
276
+ })
277
+
278
+ with pytest.raises(ValueError, match="Unsupported algorithm"):
279
+ agent.train(sample_data, total_timesteps=100)
280
+
281
+ def test_predict_without_training(self, finrl_config, sample_data):
282
+ """Test prediction without training"""
283
+ agent = FinRLAgent(finrl_config)
284
+
285
+ with pytest.raises(ValueError, match="Model not trained"):
286
+ agent.predict(sample_data)
287
+
288
+ def test_evaluate_without_training(self, finrl_config, sample_data):
289
+ """Test evaluation without training"""
290
+ agent = FinRLAgent(finrl_config)
291
+
292
+ with pytest.raises(ValueError, match="Model not trained"):
293
+ agent.evaluate(sample_data)
294
+
295
+ @patch('agentic_ai_system.finrl_agent.PPO')
296
+ def test_save_and_load_model(self, mock_ppo, finrl_config, sample_data):
297
+ """Test model saving and loading"""
298
+ # Mock the PPO model
299
+ mock_model = Mock()
300
+ mock_ppo.return_value = mock_model
301
+ mock_ppo.load.return_value = mock_model
302
+
303
+ agent = FinRLAgent(finrl_config)
304
+
305
+ # Train the agent
306
+ agent.train(sample_data, total_timesteps=100)
307
+
308
+ # Test saving
309
+ with tempfile.NamedTemporaryFile(suffix='.zip', delete=False) as tmp_file:
310
+ agent.save_model(tmp_file.name)
311
+ mock_model.save.assert_called_once_with(tmp_file.name)
312
+
313
+ # Test loading
314
+ agent.load_model(tmp_file.name)
315
+ mock_ppo.load.assert_called_once_with(tmp_file.name)
316
+
317
+ # Clean up
318
+ os.unlink(tmp_file.name)
319
+
320
+
321
+ class TestFinRLIntegration:
322
+ """Test FinRL integration with configuration"""
323
+
324
+ def test_create_agent_from_config(self):
325
+ """Test creating agent from configuration file"""
326
+ config_data = {
327
+ 'finrl': {
328
+ 'algorithm': 'PPO',
329
+ 'learning_rate': 0.001,
330
+ 'batch_size': 128,
331
+ 'gamma': 0.95
332
+ }
333
+ }
334
+
335
+ with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml', delete=False) as tmp_file:
336
+ yaml.dump(config_data, tmp_file)
337
+ tmp_file_path = tmp_file.name
338
+
339
+ try:
340
+ agent = create_finrl_agent_from_config(tmp_file_path)
341
+
342
+ assert agent.config.algorithm == 'PPO'
343
+ assert agent.config.learning_rate == 0.001
344
+ assert agent.config.batch_size == 128
345
+ assert agent.config.gamma == 0.95
346
+ finally:
347
+ os.unlink(tmp_file_path)
348
+
349
+ def test_create_agent_from_config_missing_finrl(self):
350
+ """Test creating agent from config without finrl section"""
351
+ config_data = {
352
+ 'trading': {
353
+ 'symbol': 'AAPL',
354
+ 'capital': 100000
355
+ }
356
+ }
357
+
358
+ with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml', delete=False) as tmp_file:
359
+ yaml.dump(config_data, tmp_file)
360
+ tmp_file_path = tmp_file.name
361
+
362
+ try:
363
+ agent = create_finrl_agent_from_config(tmp_file_path)
364
+
365
+ # Should use default values
366
+ assert agent.config.algorithm == 'PPO'
367
+ assert agent.config.learning_rate == 0.0003
368
+ finally:
369
+ os.unlink(tmp_file_path)
370
+
371
+
372
+ if __name__ == "__main__":
373
+ pytest.main([__file__])