Huanzhi (Hans) Mao
init
4d1746c
import random
from copy import deepcopy
from datetime import datetime, time, timedelta
from typing import Dict, List, Optional, Union
from .long_context import (
AUTOMOBILE_EXTENSION,
MA_5_EXTENSION,
MA_20_EXTENSION,
ORDER_DETAIL_EXTENSION,
TECHNOLOGY_EXTENSION,
TRANSACTION_HISTORY_EXTENSION,
WATCH_LIST_EXTENSION,
)
# Fixed "now" used by the simulated trading system (2024-09-01, 10:30).
CURRENT_TIME = datetime(2024, 9, 1, 10, 30)

# Baseline state used for every key that a loaded scenario does not provide.
DEFAULT_STATE = {
    # Orders keyed by their integer order ID.
    "orders": {
        12345: {
            "id": 12345,
            "order_type": "Buy",
            "symbol": "AAPL",
            "price": 210.65,
            "amount": 10,
            "status": "Completed",
        },
        12446: {
            "id": 12446,
            "order_type": "Sell",
            "symbol": "GOOG",
            "price": 2840.56,
            "amount": 5,
            "status": "Pending",
        },
    },
    "account_info": {
        "account_id": 12345,
        "balance": 10000.0,
        "binding_card": 1974202140965533,
    },
    "authenticated": False,
    "market_status": "Closed",
    "order_counter": 12446,
    # Per-symbol market data.
    "stocks": {
        "AAPL": {
            "price": 227.16,
            "percent_change": 0.17,
            "volume": 2.552,
            "MA(5)": 227.11,
            "MA(20)": 227.09,
        },
        "GOOG": {
            "price": 2840.34,
            "percent_change": 0.24,
            "volume": 1.123,
            "MA(5)": 2835.67,
            "MA(20)": 2842.15,
        },
        "TSLA": {
            "price": 667.92,
            "percent_change": -0.12,
            "volume": 1.654,
            "MA(5)": 671.15,
            "MA(20)": 668.20,
        },
        "MSFT": {
            "price": 310.23,
            "percent_change": 0.09,
            "volume": 3.234,
            "MA(5)": 309.88,
            "MA(20)": 310.11,
        },
        "NVDA": {
            "price": 220.34,
            "percent_change": 0.34,
            "volume": 1.234,
            "MA(5)": 220.45,
            "MA(20)": 220.67,
        },
        "ALPH": {
            "price": 1320.45,
            "percent_change": -0.08,
            "volume": 1.567,
            "MA(5)": 1321.12,
            "MA(20)": 1325.78,
        },
        "OMEG": {
            "price": 457.23,
            "percent_change": 0.12,
            "volume": 2.345,
            "MA(5)": 456.78,
            "MA(20)": 458.12,
        },
        "QUAS": {
            "price": 725.89,
            "percent_change": -0.03,
            "volume": 1.789,
            "MA(5)": 726.45,
            "MA(20)": 728.00,
        },
        "NEPT": {
            "price": 88.34,
            "percent_change": 0.19,
            "volume": 0.654,
            "MA(5)": 88.21,
            "MA(20)": 88.67,
        },
        "SYNX": {
            "price": 345.67,
            "percent_change": 0.11,
            "volume": 2.112,
            "MA(5)": 345.34,
            "MA(20)": 346.12,
        },
        "ZETA": {
            "price": 22.09,
            "percent_change": -0.05,
            "volume": 0.789,
            "MA(5)": 22.12,
            "MA(20)": 22.34,
        },
    },
    "watch_list": ["NVDA"],
    "transaction_history": [],
    # Seed for the per-instance random generator.
    "random_seed": 1053520,
}
class TradingBot:
    """
    A class representing a trading bot for executing stock trades and managing a trading account.

    All public methods return plain dictionaries. Failures that the bot
    detects itself are reported as ``{"error": "<message>"}`` rather than
    raised.

    Attributes:
        orders (Dict[int, Dict[str, Union[str, float, int]]]): A dictionary of orders for purchasing and selling of stock, keyed by order ID.
        account_info (Dict[str, Union[int, float]]): Information about the trading account.
        authenticated (bool): Whether the user is currently authenticated.
        market_status (str): The current status of the market ('Open' or 'Closed').
        order_counter (int): A counter for generating unique order IDs.
        stocks (Dict[str, Dict[str, Union[float, int]]]): Information about various stocks.
        watch_list (List[str]): A list of stock symbols being watched.
        transaction_history (List[Dict[str, Union[str, float, int]]]): A history of trading account related transactions.
        long_context (bool): Whether responses are padded with the long-context extension data.
    """

    def __init__(self):
        """
        Initialize the TradingBot instance.

        Only the API description is assigned here. The state attributes are
        merely declared; they receive their values in `_load_scenario`, which
        must be called before any other method is used.
        """
        self.orders: Dict[int, Dict[str, Union[str, float, int]]]
        self.account_info: Dict[str, Union[int, float]]
        self.authenticated: bool
        self.market_status: str
        self.order_counter: int
        self.stocks: Dict[str, Dict[str, Union[float, int]]]
        self.watch_list: List[str]
        self.transaction_history: List[Dict[str, Union[str, float, int]]]
        self.long_context: bool
        self._random: random.Random
        self._api_description = "This tool belongs to the trading system, which allows users to trade stocks, manage their account, and view stock information."

    def _load_scenario(self, scenario: dict, long_context=False) -> None:
        """
        Load a scenario into the TradingBot.

        Every key missing from `scenario` falls back to the corresponding
        entry of `DEFAULT_STATE`.

        Args:
            scenario (dict): A scenario dictionary containing data to load.
            long_context (bool): When True, several getters append the
                long-context extension data to their results.
        """
        defaults = deepcopy(DEFAULT_STATE)
        # Work on a private copy: the bot mutates orders, balances and lists
        # in place, and those changes must not leak into the caller's dict.
        scenario = deepcopy(scenario)

        def pick(key):
            return scenario.get(key, defaults[key])

        # Convert all string keys that can be interpreted as integers to integer keys
        self.orders = {
            int(key) if isinstance(key, str) and key.isdigit() else key: order
            for key, order in pick("orders").items()
        }
        self.account_info = pick("account_info")
        self.authenticated = pick("authenticated")
        self.market_status = pick("market_status")
        # The ID that the next call to place_order will use.
        self.order_counter = pick("order_counter")
        self.stocks = pick("stocks")
        self.watch_list = pick("watch_list")
        self.transaction_history = pick("transaction_history")
        self.long_context = long_context
        # Private, seeded generator so generated timestamps are reproducible.
        self._random = random.Random(pick("random_seed"))

    def _generate_transaction_timestamp(self) -> str:
        """
        Generate a timestamp for a transaction.

        The timestamp is drawn from the seeded generator and lies between
        `CURRENT_TIME` and one day later (both inclusive, one-second steps).

        Returns:
            timestamp (str): A timestamp formatted as 'YYYY-MM-DD HH:MM:SS'.
        """
        # Add a random offset directly instead of round-tripping through
        # epoch seconds: naive datetimes are interpreted in the host's local
        # time zone by timestamp()/fromtimestamp(), which made the result
        # depend on where the code runs.
        seconds_per_day = int(timedelta(days=1).total_seconds())
        offset = self._random.randint(0, seconds_per_day)
        moment = CURRENT_TIME + timedelta(seconds=offset)
        return moment.strftime("%Y-%m-%d %H:%M:%S")

    def get_current_time(self) -> Dict[str, str]:
        """
        Get the current time.

        Returns:
            current_time (str): Current time in HH:MM AM/PM format.
        """
        return {"current_time": CURRENT_TIME.strftime("%I:%M %p")}

    def update_market_status(self, current_time_str: str) -> Dict[str, str]:
        """
        Update the market status based on the current time.

        The market counts as open from 9:30 AM to 4:00 PM, both inclusive.

        Args:
            current_time_str (str): Current time in HH:MM AM/PM format.

        Returns:
            status (str): Status of the market. [Enum]: ["Open", "Closed"]

        Raises:
            ValueError: If `current_time_str` does not match the expected format.
        """
        market_open_time = time(9, 30)  # Market opens at 9:30 AM
        market_close_time = time(16, 0)  # Market closes at 4:00 PM
        current_time = datetime.strptime(current_time_str, "%I:%M %p").time()

        is_open = market_open_time <= current_time <= market_close_time
        self.market_status = "Open" if is_open else "Closed"
        return {"status": self.market_status}

    def get_symbol_by_name(self, name: str) -> Dict[str, str]:
        """
        Get the symbol of a stock by company name.

        The lookup is an exact, case-sensitive match on the company name.

        Args:
            name (str): Name of the company.

        Returns:
            symbol (str): Symbol of the stock or "Stock not found" if not available.
        """
        symbol_map = {
            "Apple": "AAPL",
            "Google": "GOOG",
            "Tesla": "TSLA",
            "Microsoft": "MSFT",
            "Nvidia": "NVDA",
            "Zeta Corp": "ZETA",
            "Alpha Tech": "ALPH",
            "Omega Industries": "OMEG",
            "Quasar Ltd.": "QUAS",
            "Neptune Systems": "NEPT",
            "Synex Solutions": "SYNX",
            "Amazon": "AMZN",
        }
        return {"symbol": symbol_map.get(name, "Stock not found")}

    def get_stock_info(self, symbol: str) -> Dict[str, Union[float, int, str]]:
        """
        Get the details of a stock.

        Args:
            symbol (str): Symbol that uniquely identifies the stock.

        Returns:
            price (float): Current price of the stock.
            percent_change (float): Percentage change in stock price.
            volume (float): Trading volume of the stock.
            MA(5) (float): 5-day Moving Average of the stock.
            MA(20) (float): 20-day Moving Average of the stock.
        """
        if symbol not in self.stocks:
            return {"error": f"Stock with symbol '{symbol}' not found."}

        if self.long_context:
            # Pad a copy so the stored stock record keeps its real averages.
            stock = self.stocks[symbol].copy()
            stock["MA(5)"] = MA_5_EXTENSION
            stock["MA(20)"] = MA_20_EXTENSION
            return stock

        return self.stocks[symbol]

    def get_order_details(self, order_id: int) -> Dict[str, Union[str, float, int]]:
        """
        Get the details of an order.

        Args:
            order_id (int): ID of the order.

        Returns:
            id (int): ID of the order.
            order_type (str): Type of the order.
            symbol (str): Symbol of the stock in the order.
            price (float): Price at which the order was placed.
            amount (int): Number of shares in the order.
            status (str): Current status of the order. [Enum]: ["Open", "Pending", "Completed", "Cancelled"]
        """
        if order_id not in self.orders:
            # The known IDs are listed so the caller can recover from a bad ID.
            return {
                "error": f"Order with ID {order_id} not found. "
                + "Here is the list of orders_id: "
                + str(list(self.orders.keys()))
            }

        if self.long_context:
            # Pad a copy so the stored order is not modified.
            order = self.orders[order_id].copy()
            symbol = order["symbol"]

            formatted_extension = {}
            for key, value in ORDER_DETAIL_EXTENSION.items():
                try:
                    formatted_extension[key] = value.format(symbol=symbol)
                except KeyError as e:
                    return {"error": f"KeyError during formatting: {str(e)}"}

            # Add formatted extension to the order metadata
            order["metadata"] = formatted_extension
            return order

        return self.orders[order_id]

    def cancel_order(self, order_id: int) -> Dict[str, Union[int, str]]:
        """
        Cancel an order.

        Completed orders cannot be cancelled; any other existing order is
        marked as cancelled.

        Args:
            order_id (int): ID of the order to cancel.

        Returns:
            order_id (int): ID of the cancelled order.
            status (str): New status of the order after cancellation attempt.
        """
        if order_id not in self.orders:
            return {"error": f"Order with ID {order_id} not found."}

        order = self.orders[order_id]
        if order["status"] == "Completed":
            return {"error": f"Can't cancel order {order_id}. Order is already completed."}

        order["status"] = "Cancelled"
        return {"order_id": order_id, "status": "Cancelled"}

    def place_order(
        self, order_type: str, symbol: str, price: float, amount: int
    ) -> Dict[str, Union[int, str, float]]:
        """
        Place an order.

        Args:
            order_type (str): Type of the order (Buy/Sell).
            symbol (str): Symbol of the stock to trade.
            price (float): Price at which to place the order.
            amount (int): Number of shares to trade.

        Returns:
            order_id (int): ID of the newly placed order.
            order_type (str): Type of the order (Buy/Sell).
            status (str): Initial status of the order.
            price (float): Price at which the order was placed.
            amount (int): Number of shares in the order.
        """
        if not self.authenticated:
            return {"error": "User not authenticated. Please log in to place an order."}
        if symbol not in self.stocks:
            return {"error": f"Invalid stock symbol: {symbol}"}
        if price <= 0 or amount <= 0:
            return {"error": "Price and amount must be positive values."}

        price = float(price)
        # NOTE(review): the counter value itself is used as the new ID. If it
        # equals the ID of an existing order, that order is replaced here.
        # Left unchanged because callers may rely on the ID sequence - confirm.
        order_id = self.order_counter
        self.orders[order_id] = {
            "id": order_id,
            "order_type": order_type,
            "symbol": symbol,
            "price": price,
            "amount": amount,
            "status": "Open",
        }
        self.order_counter += 1

        # We return the status as "Pending" to indicate that the order has been placed but not yet executed
        # When polled later, the status will show as 'Open'
        # This is to simulate the delay between placing an order and it being executed
        return {
            "order_id": order_id,
            "order_type": order_type,
            "status": "Pending",
            "price": price,
            "amount": amount,
        }

    def make_transaction(
        self, account_id: int, xact_type: str, amount: float
    ) -> Dict[str, Union[str, float]]:
        """
        Make a deposit or withdrawal based on specified amount.

        The transaction is only carried out when the user is authenticated,
        the market is open, the account ID matches and the amount is positive.

        Args:
            account_id (int): ID of the account.
            xact_type (str): Transaction type (deposit or withdrawal).
            amount (float): Amount to deposit or withdraw.

        Returns:
            status (str): Status of the transaction.
            new_balance (float): Updated account balance after the transaction.
        """
        if not self.authenticated:
            return {"error": "User not authenticated. Please log in to make a transaction."}
        if self.market_status != "Open":
            return {"error": "Market is closed. Transactions are not allowed."}
        if account_id != self.account_info["account_id"]:
            return {"error": f"Account with ID {account_id} not found."}
        if amount <= 0:
            return {"error": "Transaction amount must be positive."}
        if xact_type not in ("deposit", "withdrawal"):
            return {"error": "Invalid transaction type. Use 'deposit' or 'withdrawal'."}

        if xact_type == "deposit":
            self.account_info["balance"] += amount
            status = "Deposit successful"
        else:
            if amount > self.account_info["balance"]:
                return {"error": "Insufficient funds for withdrawal."}
            self.account_info["balance"] -= amount
            status = "Withdrawal successful"

        self.transaction_history.append(
            {
                "type": xact_type,
                "amount": amount,
                "timestamp": self._generate_transaction_timestamp(),
            }
        )
        return {"status": status, "new_balance": self.account_info["balance"]}

    def get_account_info(self) -> Dict[str, Union[int, float]]:
        """
        Get account information.

        Returns:
            account_id (int): ID of the account.
            balance (float): Current balance of the account.
            binding_card (int): Card number associated with the account.
        """
        if not self.authenticated:
            return {
                "error": "User not authenticated. Please log in to view account information."
            }
        return self.account_info

    def trading_login(self, username: str, password: str) -> Dict[str, str]:
        """
        Handle user login.

        Args:
            username (str): Username for authentication.
            password (str): Password for authentication.

        Returns:
            status (str): Login status message.
        """
        if self.authenticated:
            return {"status": "Already logged in"}

        # In a real system, we would validate the username and password here
        self.authenticated = True
        return {"status": "Logged in successfully"}

    def trading_get_login_status(self) -> Dict[str, bool]:
        """
        Get the login status.

        Returns:
            status (bool): Login status.
        """
        return {"status": bool(self.authenticated)}

    def trading_logout(self) -> Dict[str, str]:
        """
        Handle user logout for trading system.

        Returns:
            status (str): Logout status message.
        """
        if not self.authenticated:
            return {"status": "No user is currently logged in"}

        self.authenticated = False
        return {"status": "Logged out successfully"}

    def fund_account(self, amount: float) -> Dict[str, Union[str, float]]:
        """
        Fund the account with the specified amount.

        The funding is recorded in the transaction history as a deposit.

        Args:
            amount (float): Amount to fund the account with.

        Returns:
            status (str): Status of the funding operation.
            new_balance (float): Updated account balance after funding.
        """
        if not self.authenticated:
            return {"error": "User not authenticated. Please log in to fund the account."}
        if amount <= 0:
            return {"error": "Funding amount must be positive."}

        self.account_info["balance"] += amount
        self.transaction_history.append(
            {
                "type": "deposit",
                "amount": amount,
                "timestamp": self._generate_transaction_timestamp(),
            }
        )
        return {
            "status": "Account funded successfully",
            "new_balance": self.account_info["balance"],
        }

    def remove_stock_from_watchlist(self, symbol: str) -> Dict[str, str]:
        """
        Remove a stock from the watchlist.

        Args:
            symbol (str): Symbol of the stock to remove.

        Returns:
            status (str): Status of the removal operation.
        """
        if not self.authenticated:
            return {
                "error": "User not authenticated. Please log in to modify the watchlist."
            }
        if symbol not in self.watch_list:
            return {"error": f"Stock {symbol} not found in watchlist."}

        self.watch_list.remove(symbol)
        return {"status": f"Stock {symbol} removed from watchlist successfully."}

    def get_watchlist(self) -> Dict[str, Union[str, List[str]]]:
        """
        Get the watchlist.

        Returns:
            watchlist (List[str]): List of stock symbols in the watchlist.
        """
        if not self.authenticated:
            # Same error shape as every other method of this class.
            return {
                "error": "User not authenticated. Please log in to view the watchlist."
            }

        if self.long_context:
            # Pad a copy so the stored watchlist is not modified.
            watch_list = self.watch_list.copy()
            watch_list.extend(WATCH_LIST_EXTENSION)
            return {"watchlist": watch_list}

        return {"watchlist": self.watch_list}

    def get_order_history(self) -> Dict[str, Union[str, List[int]]]:
        """
        Get the stock order ID history.

        Returns:
            history (List[int]): List of orders ID in the order history.
        """
        if not self.authenticated:
            return {
                "error": "User not authenticated. Please log in to view order history."
            }
        return {"history": list(self.orders.keys())}

    def get_transaction_history(
        self, start_date: Optional[str] = None, end_date: Optional[str] = None
    ) -> Dict[str, Union[str, List[Dict[str, Union[str, float]]]]]:
        """
        Get the transaction history within a specified date range.

        Both bounds are inclusive calendar days: a transaction made at any
        time on `end_date` is part of the result.

        Args:
            start_date (str): [Optional] Start date for the history (format: 'YYYY-MM-DD').
            end_date (str): [Optional] End date for the history (format: 'YYYY-MM-DD').

        Returns:
            transaction_history (List[Dict]): List of transactions within the specified date range.
                - type (str): Type of transaction. [Enum]: ["deposit", "withdrawal"]
                - amount (float): Amount involved in the transaction.
                - timestamp (str): Timestamp of the transaction, formatted as 'YYYY-MM-DD HH:MM:SS'.

        Raises:
            ValueError: If a given date or a stored timestamp does not match its format.
        """
        if not self.authenticated:
            return {
                "error": "User not authenticated. Please log in to view transaction history."
            }

        # Compare calendar days, not instants: parsing end_date yields
        # midnight, which would otherwise exclude the whole end day.
        if start_date:
            first_day = datetime.strptime(start_date, "%Y-%m-%d").date()
        else:
            first_day = datetime.min.date()
        if end_date:
            last_day = datetime.strptime(end_date, "%Y-%m-%d").date()
        else:
            last_day = datetime.max.date()

        filtered_history = [
            transaction
            for transaction in self.transaction_history
            if first_day
            <= datetime.strptime(transaction["timestamp"], "%Y-%m-%d %H:%M:%S").date()
            <= last_day
        ]

        if self.long_context:
            # filtered_history is a fresh list, so the stored history is untouched.
            filtered_history.extend(TRANSACTION_HISTORY_EXTENSION)

        return {"transaction_history": filtered_history}

    def update_stock_price(
        self, symbol: str, new_price: float
    ) -> Dict[str, Union[str, float]]:
        """
        Update the price of a stock.

        The stock's percent change is recomputed relative to its previous price.

        Args:
            symbol (str): Symbol of the stock to update.
            new_price (float): New price of the stock.

        Returns:
            symbol (str): Symbol of the updated stock.
            old_price (float): Previous price of the stock.
            new_price (float): Updated price of the stock.
        """
        if symbol not in self.stocks:
            return {"error": f"Stock with symbol '{symbol}' not found."}
        if new_price <= 0:
            return {"error": "New price must be a positive value."}

        stock = self.stocks[symbol]
        old_price = stock["price"]
        stock["price"] = new_price
        stock["percent_change"] = ((new_price - old_price) / old_price) * 100

        return {"symbol": symbol, "old_price": old_price, "new_price": new_price}

    # below contains a list of functions to be nested
    def get_available_stocks(self, sector: str) -> Dict[str, List[str]]:
        """
        Get a list of stock symbols in the given sector.

        Args:
            sector (str): The sector to retrieve stocks from (e.g., 'Technology').

        Returns:
            stock_list (List[str]): List of stock symbols in the specified sector.
        """
        # Built per call, so extending the lists below affects only this result.
        sector_map = {
            "Technology": ["AAPL", "GOOG", "MSFT", "NVDA"],
            "Automobile": ["TSLA", "F", "GM"],
        }

        if self.long_context:
            sector_map["Technology"].extend(TECHNOLOGY_EXTENSION)
            sector_map["Automobile"].extend(AUTOMOBILE_EXTENSION)

        return {"stock_list": sector_map.get(sector, [])}

    def filter_stocks_by_price(
        self, stocks: List[str], min_price: float, max_price: float
    ) -> Dict[str, List[str]]:
        """
        Filter stocks based on a price range.

        Symbols that the bot has no data for are dropped from the result.

        Args:
            stocks (List[str]): List of stock symbols to filter.
            min_price (float): Minimum stock price.
            max_price (float): Maximum stock price.

        Returns:
            filtered_stocks (List[str]): Filtered list of stock symbols within the price range.
        """
        # Unknown symbols must not be treated as priced at 0: they would
        # pass the filter whenever min_price <= 0.
        filtered_stocks = [
            symbol
            for symbol in stocks
            if symbol in self.stocks
            and min_price <= self.stocks[symbol].get("price", 0) <= max_price
        ]
        return {"filtered_stocks": filtered_stocks}

    def add_to_watchlist(self, stock: str) -> Dict[str, List[str]]:
        """
        Add a stock to the watchlist.

        Symbols that are unknown or already watched are ignored silently.

        Args:
            stock (str): the stock symbol to add to the watchlist.

        Returns:
            symbol (List[str]): the complete watchlist after the operation.
        """
        # Ensure symbol is valid and not yet watched
        if stock not in self.watch_list and stock in self.stocks:
            self.watch_list.append(stock)
        return {"symbol": self.watch_list}

    def notify_price_change(self, stocks: List[str], threshold: float) -> Dict[str, str]:
        """
        Notify if there is a significant price change in the stocks.

        A change is significant when the absolute percent change is at least
        `threshold`. Unknown symbols are ignored.

        Args:
            stocks (List[str]): List of stock symbols to check.
            threshold (float): Percentage change threshold to trigger a notification.

        Returns:
            notification (str): Notification message about the price changes.
        """
        changed_stocks = [
            symbol
            for symbol in stocks
            if symbol in self.stocks
            and abs(self.stocks[symbol]["percent_change"]) >= threshold
        ]

        if not changed_stocks:
            return {"notification": "No significant price changes in the selected stocks."}
        return {
            "notification": f"Stocks {', '.join(changed_stocks)} have significant price changes."
        }