import random from copy import deepcopy from datetime import datetime, time, timedelta from typing import Dict, List, Optional, Union from .long_context import ( AUTOMOBILE_EXTENSION, MA_5_EXTENSION, MA_20_EXTENSION, ORDER_DETAIL_EXTENSION, TECHNOLOGY_EXTENSION, TRANSACTION_HISTORY_EXTENSION, WATCH_LIST_EXTENSION, ) CURRENT_TIME = datetime(2024, 9, 1, 10, 30) DEFAULT_STATE = { "orders": { 12345: { "id": 12345, "order_type": "Buy", "symbol": "AAPL", "price": 210.65, "amount": 10, "status": "Completed", }, 12446: { "id": 12446, "order_type": "Sell", "symbol": "GOOG", "price": 2840.56, "amount": 5, "status": "Pending", }, }, "account_info": { "account_id": 12345, "balance": 10000.0, "binding_card": 1974202140965533, }, "authenticated": False, "market_status": "Closed", "order_counter": 12446, "stocks": { "AAPL": { "price": 227.16, "percent_change": 0.17, "volume": 2.552, "MA(5)": 227.11, "MA(20)": 227.09, }, "GOOG": { "price": 2840.34, "percent_change": 0.24, "volume": 1.123, "MA(5)": 2835.67, "MA(20)": 2842.15, }, "TSLA": { "price": 667.92, "percent_change": -0.12, "volume": 1.654, "MA(5)": 671.15, "MA(20)": 668.20, }, "MSFT": { "price": 310.23, "percent_change": 0.09, "volume": 3.234, "MA(5)": 309.88, "MA(20)": 310.11, }, "NVDA": { "price": 220.34, "percent_change": 0.34, "volume": 1.234, "MA(5)": 220.45, "MA(20)": 220.67, }, "ALPH": { "price": 1320.45, "percent_change": -0.08, "volume": 1.567, "MA(5)": 1321.12, "MA(20)": 1325.78, }, "OMEG": { "price": 457.23, "percent_change": 0.12, "volume": 2.345, "MA(5)": 456.78, "MA(20)": 458.12, }, "QUAS": { "price": 725.89, "percent_change": -0.03, "volume": 1.789, "MA(5)": 726.45, "MA(20)": 728.00, }, "NEPT": { "price": 88.34, "percent_change": 0.19, "volume": 0.654, "MA(5)": 88.21, "MA(20)": 88.67, }, "SYNX": { "price": 345.67, "percent_change": 0.11, "volume": 2.112, "MA(5)": 345.34, "MA(20)": 346.12, }, "ZETA": { "price": 22.09, "percent_change": -0.05, "volume": 0.789, "MA(5)": 22.12, "MA(20)": 22.34, }, }, "watch_list": ["NVDA"], "transaction_history": [], "random_seed": 1053520, } class TradingBot: """ A class representing a trading bot for executing stock trades and managing a trading account. Attributes: orders (Dict[int, Dict[str, Union[str, float, int]]]): A dictionary of orders for purchasing and selling of stock, keyed by order ID. account_info (Dict[str, Union[int, float]]): Information about the trading account. authenticated (bool): Whether the user is currently authenticated. market_status (str): The current status of the market ('Open' or 'Closed'). order_counter (int): A counter for generating unique order IDs. stocks (Dict[str, Dict[str, Union[float, int]]]): Information about various stocks. watch_list (List[str]): A list of stock symbols being watched. transaction_history (List[Dict[str, Union[str, float, int]]]): A history of trading account related transactions. """ def __init__(self): """ Initialize the TradingBot instance. """ self.orders: Dict[int, Dict[str, Union[str, float, int]]] self.account_info: Dict[str, Union[int, float]] self.authenticated: bool self.market_status: str self.order_counter: int self.stocks: Dict[str, Dict[str, Union[float, int]]] self.watch_list: List[str] self.transaction_history: List[Dict[str, Union[str, float, int]]] self._api_description = "This tool belongs to the trading system, which allows users to trade stocks, manage their account, and view stock information." def _load_scenario(self, scenario: dict, long_context=False) -> None: """ Load a scenario into the TradingBot. Args: scenario (dict): A scenario dictionary containing data to load. """ DEFAULT_STATE_COPY = deepcopy(DEFAULT_STATE) self.orders = scenario.get("orders", DEFAULT_STATE_COPY["orders"]) # Convert all string keys that can be interpreted as integers to integer keys self.orders = { int(k) if isinstance(k, str) and k.isdigit() else k: v for k, v in self.orders.items() } self.account_info = scenario.get("account_info", DEFAULT_STATE_COPY["account_info"]) self.authenticated = scenario.get( "authenticated", DEFAULT_STATE_COPY["authenticated"] ) self.market_status = scenario.get( "market_status", DEFAULT_STATE_COPY["market_status"] ) self.order_counter = scenario.get( "order_counter", DEFAULT_STATE_COPY["order_counter"] ) # Start counter from the next order ID self.stocks = scenario.get("stocks", DEFAULT_STATE_COPY["stocks"]) self.watch_list = scenario.get("watch_list", DEFAULT_STATE_COPY["watch_list"]) self.transaction_history = scenario.get( "transaction_history", DEFAULT_STATE_COPY["transaction_history"] ) self.long_context = long_context self._random = random.Random( (scenario.get("random_seed", DEFAULT_STATE_COPY["random_seed"])) ) def _generate_transaction_timestamp(self) -> str: """ Generate a timestamp for a transaction. Returns: timestamp (str): A formatted timestamp string. """ # Define the start and end dates for the range start_date = CURRENT_TIME end_date = CURRENT_TIME + timedelta(days=1) start_timestamp = int(start_date.timestamp()) end_timestamp = int(end_date.timestamp()) # Generate a random timestamp within the range random_timestamp = self._random.randint(start_timestamp, end_timestamp) # Convert the random timestamp to a datetime object random_date = datetime.fromtimestamp(random_timestamp) return random_date.strftime("%Y-%m-%d %H:%M:%S") def get_current_time(self) -> Dict[str, str]: """ Get the current time. Returns: current_time (str): Current time in HH:MM AM/PM format. """ return {"current_time": CURRENT_TIME.strftime("%I:%M %p")} def update_market_status(self, current_time_str: str) -> Dict[str, str]: """ Update the market status based on the current time. Args: current_time_str (str): Current time in HH:MM AM/PM format. Returns: status (str): Status of the market. [Enum]: ["Open", "Closed"] """ market_open_time = time(9, 30) # Market opens at 9:30 AM market_close_time = time(16, 0) # Market closes at 4:00 PM current_time = datetime.strptime(current_time_str, "%I:%M %p").time() if market_open_time <= current_time <= market_close_time: self.market_status = "Open" return {"status": "Open"} else: self.market_status = "Closed" return {"status": "Closed"} def get_symbol_by_name(self, name: str) -> Dict[str, str]: """ Get the symbol of a stock by company name. Args: name (str): Name of the company. Returns: symbol (str): Symbol of the stock or "Stock not found" if not available. """ symbol_map = { "Apple": "AAPL", "Google": "GOOG", "Tesla": "TSLA", "Microsoft": "MSFT", "Nvidia": "NVDA", "Zeta Corp": "ZETA", "Alpha Tech": "ALPH", "Omega Industries": "OMEG", "Quasar Ltd.": "QUAS", "Neptune Systems": "NEPT", "Synex Solutions": "SYNX", "Amazon": "AMZN", } return {"symbol": symbol_map.get(name, "Stock not found")} def get_stock_info(self, symbol: str) -> Dict[str, Union[float, int, str]]: """ Get the details of a stock. Args: symbol (str): Symbol that uniquely identifies the stock. Returns: price (float): Current price of the stock. percent_change (float): Percentage change in stock price. volume (float): Trading volume of the stock. MA(5) (float): 5-day Moving Average of the stock. MA(20) (float): 20-day Moving Average of the stock. """ if symbol not in self.stocks: return {"error": f"Stock with symbol '{symbol}' not found."} if self.long_context: stock = self.stocks[symbol].copy() stock["MA(5)"] = MA_5_EXTENSION stock["MA(20)"] = MA_20_EXTENSION return stock return self.stocks[symbol] def get_order_details(self, order_id: int) -> Dict[str, Union[str, float, int]]: """ Get the details of an order. Args: order_id (int): ID of the order. Returns: id (int): ID of the order. order_type (str): Type of the order. symbol (str): Symbol of the stock in the order. price (float): Price at which the order was placed. amount (int): Number of shares in the order. status (str): Current status of the order. [Enum]: ["Open", "Pending", "Completed", "Cancelled"] """ if order_id not in self.orders: return { "error": f"Order with ID {order_id} not found." + "Here is the list of orders_id: " + str(list(self.orders.keys())) } if self.long_context: order = self.orders[order_id].copy() symbol = order["symbol"] formatted_extension = {} for key, value in ORDER_DETAIL_EXTENSION.items(): try: formatted_extension[key] = value.format(symbol=symbol) except KeyError as e: return {"error": f"KeyError during formatting: {str(e)}"} # Add formatted extension to the order metadata order["metadata"] = formatted_extension return order return self.orders[order_id] def cancel_order(self, order_id: int) -> Dict[str, Union[int, str]]: """ Cancel an order. Args: order_id (int): ID of the order to cancel. Returns: order_id (int): ID of the cancelled order. status (str): New status of the order after cancellation attempt. """ if order_id not in self.orders: return {"error": f"Order with ID {order_id} not found."} if self.orders[order_id]["status"] == "Completed": return {"error": f"Can't cancel order {order_id}. Order is already completed."} self.orders[order_id]["status"] = "Cancelled" return {"order_id": order_id, "status": "Cancelled"} def place_order( self, order_type: str, symbol: str, price: float, amount: int ) -> Dict[str, Union[int, str, float]]: """ Place an order. Args: order_type (str): Type of the order (Buy/Sell). symbol (str): Symbol of the stock to trade. price (float): Price at which to place the order. amount (int): Number of shares to trade. Returns: order_id (int): ID of the newly placed order. order_type (str): Type of the order (Buy/Sell). status (str): Initial status of the order. price (float): Price at which the order was placed. amount (int): Number of shares in the order. """ if not self.authenticated: return {"error": "User not authenticated. Please log in to place an order."} if symbol not in self.stocks: return {"error": f"Invalid stock symbol: {symbol}"} if price <= 0 or amount <= 0: return {"error": "Price and amount must be positive values."} price = float(price) order_id = self.order_counter self.orders[order_id] = { "id": order_id, "order_type": order_type, "symbol": symbol, "price": price, "amount": amount, "status": "Open", } self.order_counter += 1 # We return the status as "Pending" to indicate that the order has been placed but not yet executed # When polled later, the status will show as 'Open' # This is to simulate the delay between placing an order and it being executed return { "order_id": order_id, "order_type": order_type, "status": "Pending", "price": price, "amount": amount, } def make_transaction( self, account_id: int, xact_type: str, amount: float ) -> Dict[str, Union[str, float]]: """ Make a deposit or withdrawal based on specified amount. Args: account_id (int): ID of the account. xact_type (str): Transaction type (deposit or withdrawal). amount (float): Amount to deposit or withdraw. Returns: status (str): Status of the transaction. new_balance (float): Updated account balance after the transaction. """ if not self.authenticated: return {"error": "User not authenticated. Please log in to make a transaction."} if self.market_status != "Open": return {"error": "Market is closed. Transactions are not allowed."} if account_id != self.account_info["account_id"]: return {"error": f"Account with ID {account_id} not found."} if amount <= 0: return {"error": "Transaction amount must be positive."} if xact_type == "deposit": self.account_info["balance"] += amount self.transaction_history.append( { "type": "deposit", "amount": amount, "timestamp": self._generate_transaction_timestamp(), } ) return { "status": "Deposit successful", "new_balance": self.account_info["balance"], } elif xact_type == "withdrawal": if amount > self.account_info["balance"]: return {"error": "Insufficient funds for withdrawal."} self.account_info["balance"] -= amount self.transaction_history.append( { "type": "withdrawal", "amount": amount, "timestamp": self._generate_transaction_timestamp(), } ) return { "status": "Withdrawal successful", "new_balance": self.account_info["balance"], } return {"error": "Invalid transaction type. Use 'deposit' or 'withdrawal'."} def get_account_info(self) -> Dict[str, Union[int, float]]: """ Get account information. Returns: account_id (int): ID of the account. balance (float): Current balance of the account. binding_card (int): Card number associated with the account. """ if not self.authenticated: return { "error": "User not authenticated. Please log in to view account information." } return self.account_info def trading_login(self, username: str, password: str) -> Dict[str, str]: """ Handle user login. Args: username (str): Username for authentication. password (str): Password for authentication. Returns: status (str): Login status message. """ if self.authenticated: return {"status": "Already logged in"} # In a real system, we would validate the username and password here self.authenticated = True return {"status": "Logged in successfully"} def trading_get_login_status(self) -> Dict[str, bool]: """ Get the login status. Returns: status (bool): Login status. """ return {"status": bool(self.authenticated)} def trading_logout(self) -> Dict[str, str]: """ Handle user logout for trading system. Returns: status (str): Logout status message. """ if not self.authenticated: return {"status": "No user is currently logged in"} self.authenticated = False return {"status": "Logged out successfully"} def fund_account(self, amount: float) -> Dict[str, Union[str, float]]: """ Fund the account with the specified amount. Args: amount (float): Amount to fund the account with. Returns: status (str): Status of the funding operation. new_balance (float): Updated account balance after funding. """ if not self.authenticated: return {"error": "User not authenticated. Please log in to fund the account."} if amount <= 0: return {"error": "Funding amount must be positive."} self.account_info["balance"] += amount self.transaction_history.append( {"type": "deposit", "amount": amount, "timestamp": self._generate_transaction_timestamp()} ) return { "status": "Account funded successfully", "new_balance": self.account_info["balance"], } def remove_stock_from_watchlist(self, symbol: str) -> Dict[str, str]: """ Remove a stock from the watchlist. Args: symbol (str): Symbol of the stock to remove. Returns: status (str): Status of the removal operation. """ if not self.authenticated: return { "error": "User not authenticated. Please log in to modify the watchlist." } if symbol not in self.watch_list: return {"error": f"Stock {symbol} not found in watchlist."} self.watch_list.remove(symbol) return {"status": f"Stock {symbol} removed from watchlist successfully."} def get_watchlist(self) -> Dict[str, List[str]]: """ Get the watchlist. Returns: watchlist (List[str]): List of stock symbols in the watchlist. """ if not self.authenticated: return ["Error: User not authenticated. Please log in to view the watchlist."] if self.long_context: watch_list = self.watch_list.copy() watch_list.extend(WATCH_LIST_EXTENSION) return watch_list return {"watchlist": self.watch_list} def get_order_history(self) -> Dict[str, List[Dict[str, Union[str, int, float]]]]: """ Get the stock order ID history. Returns: order_history (List[int]): List of orders ID in the order history. """ if not self.authenticated: return [ { "error": "User not authenticated. Please log in to view order history." } ] return {"history": list(self.orders.keys())} def get_transaction_history( self, start_date: Optional[str] = None, end_date: Optional[str] = None ) -> Dict[str, List[Dict[str, Union[str, float]]]]: """ Get the transaction history within a specified date range. Args: start_date (str): [Optional] Start date for the history (format: 'YYYY-MM-DD'). end_date (str): [Optional] End date for the history (format: 'YYYY-MM-DD'). Returns: transaction_history (List[Dict]): List of transactions within the specified date range. - type (str): Type of transaction. [Enum]: ["deposit", "withdrawal"] - amount (float): Amount involved in the transaction. - timestamp (str): Timestamp of the transaction, formatted as 'YYYY-MM-DD HH:MM:SS'. """ if not self.authenticated: return [ { "error": "User not authenticated. Please log in to view transaction history." } ] if start_date: start = datetime.strptime(start_date, "%Y-%m-%d") else: start = datetime.min if end_date: end = datetime.strptime(end_date, "%Y-%m-%d") else: end = datetime.max filtered_history = [ transaction for transaction in self.transaction_history if start <= datetime.strptime(transaction["timestamp"], "%Y-%m-%d %H:%M:%S") <= end ] if self.long_context: filtered_history.extend(TRANSACTION_HISTORY_EXTENSION) return {"transaction_history": filtered_history} def update_stock_price( self, symbol: str, new_price: float ) -> Dict[str, Union[str, float]]: """ Update the price of a stock. Args: symbol (str): Symbol of the stock to update. new_price (float): New price of the stock. Returns: symbol (str): Symbol of the updated stock. old_price (float): Previous price of the stock. new_price (float): Updated price of the stock. """ if symbol not in self.stocks: return {"error": f"Stock with symbol '{symbol}' not found."} if new_price <= 0: return {"error": "New price must be a positive value."} old_price = self.stocks[symbol]["price"] self.stocks[symbol]["price"] = new_price self.stocks[symbol]["percent_change"] = ((new_price - old_price) / old_price) * 100 return {"symbol": symbol, "old_price": old_price, "new_price": new_price} # below contains a list of functions to be nested def get_available_stocks(self, sector: str) -> Dict[str, List[str]]: """ Get a list of stock symbols in the given sector. Args: sector (str): The sector to retrieve stocks from (e.g., 'Technology'). Returns: stock_list (List[str]): List of stock symbols in the specified sector. """ sector_map = { "Technology": ["AAPL", "GOOG", "MSFT", "NVDA"], "Automobile": ["TSLA", "F", "GM"], } if self.long_context: sector_map["Technology"].extend(TECHNOLOGY_EXTENSION) sector_map["Automobile"].extend(AUTOMOBILE_EXTENSION) return {"stock_list": sector_map.get(sector, [])} def filter_stocks_by_price( self, stocks: List[str], min_price: float, max_price: float ) -> Dict[str, List[str]]: """ Filter stocks based on a price range. Args: stocks (List[str]): List of stock symbols to filter. min_price (float): Minimum stock price. max_price (float): Maximum stock price. Returns: filtered_stocks (List[str]): Filtered list of stock symbols within the price range. """ filtered_stocks = [ symbol for symbol in stocks if self.stocks.get(symbol, {}).get("price", 0) >= min_price and self.stocks.get(symbol, {}).get("price", 0) <= max_price ] return {"filtered_stocks": filtered_stocks} def add_to_watchlist(self, stock: str) -> Dict[str, List[str]]: """ Add a stock to the watchlist. Args: stock (str): the stock symbol to add to the watchlist. Returns: symbol (str): the symbol that were successfully added to the watchlist. """ if stock not in self.watch_list: if stock in self.stocks: # Ensure symbol is valid self.watch_list.append(stock) return {"symbol": self.watch_list} def notify_price_change(self, stocks: List[str], threshold: float) -> Dict[str, str]: """ Notify if there is a significant price change in the stocks. Args: stocks (List[str]): List of stock symbols to check. threshold (float): Percentage change threshold to trigger a notification. Returns: notification (str): Notification message about the price changes. """ changed_stocks = [ symbol for symbol in stocks if symbol in self.stocks and abs(self.stocks[symbol]["percent_change"]) >= threshold ] if changed_stocks: return {"notification": f"Stocks {', '.join(changed_stocks)} have significant price changes."} else: return {"notification": "No significant price changes in the selected stocks."}