# finance-bot / pages/Stock_Chatbot.py
# Author: tosanoob
# Commit 44f095f: "feat: add news report"
# app.py (Final version with Advanced Charts)
import streamlit as st
import pandas as pd
import altair as alt # <-- Add Altair library
import google.generativeai as genai
import google.ai.generativelanguage as glm
from dotenv import load_dotenv
import os
from twelvedata_api import TwelveDataAPI
from collections import deque
from datetime import datetime
# --- 1. INITIAL CONFIGURATION & STATE INITIALIZATION ---
# Load variables from a .env file (if present) into the process environment.
load_dotenv()
MODEL_NAME = os.environ.get("GEMINI_MODEL", "gemini-2.5-flash")

# Page configuration kept consistent with the other pages of the app.
st.set_page_config(
    layout="wide",
    page_icon="📊",
    page_title="AI Financial Dashboard",
)

# Give the Gemini SDK its API key (read from the environment).
genai.configure(api_key=os.environ.get("GEMINI_API_KEY"))
def initialize_state():
    """Seed ``st.session_state`` with the dashboard's default values.

    Guarded by the ``initialized`` flag, so only the first run of a session does
    any work; later reruns return immediately and keep the user's state.
    """
    state = st.session_state
    if "initialized" in state:
        return
    state.initialized = True
    defaults = {
        "td_api": TwelveDataAPI(os.getenv("TWELVEDATA_API_KEY")),
        "stock_watchlist": {},            # symbol -> stock metadata dict
        "timeseries_cache": {},           # symbol -> {period: DataFrame indexed by datetime}
        "active_timeseries_period": 'intraday',
        "currency_converter_state": {'from': 'USD', 'to': 'VND', 'amount': 100.0, 'result': None},
        "chat_history": [],
        "active_tab": 'Stock Watchlist',
        "chat_session": None,             # created after the model is built
    }
    for key, value in defaults.items():
        state[key] = value


initialize_state()
# --- 2. LOAD BACKGROUND DATA ---
@st.cache_data(show_spinner="Loading and preparing market data...")
def load_market_data():
    """Fetch the stock and forex catalogues and derive lookup structures.

    Cached with ``st.cache_data`` (no ttl given), so the TwelveData calls run only
    when the cache is empty.

    Returns:
        tuple: ``(stocks_data, forex_graph, country_currency_map, all_currencies)``.

        - ``stocks_data``: the raw ``get_all_stocks()`` payload, or ``{}`` when the
          call returned nothing.
        - ``forex_graph``: maps each currency code to the codes it shares a listed
          pair with, in either direction.
        - ``country_currency_map``: maps a lower-cased country name to a currency
          (the last stock seen for that country wins).
        - ``all_currencies``: the sorted list of graph nodes.
    """
    td_api = st.session_state.td_api
    stocks_data = td_api.get_all_stocks()
    forex_data = td_api.get_forex_pairs()

    # Undirected adjacency list: a listed BASE/QUOTE pair allows conversion both ways.
    forex_graph = {}
    if forex_data and 'data' in forex_data:
        for pair in forex_data['data']:
            codes = str(pair.get('symbol') or '').split('/')
            if len(codes) != 2 or not all(codes):
                continue  # malformed entry: skip it instead of failing the whole page load
            base, quote = codes
            for src, dst in ((base, quote), (quote, base)):
                neighbours = forex_graph.setdefault(src, [])
                if dst not in neighbours:  # avoid duplicate edges for repeated/mirrored pairs
                    neighbours.append(dst)

    country_currency_map = {}
    if stocks_data and 'data' in stocks_data:
        for stock in stocks_data['data']:
            country, currency = stock.get('country'), stock.get('currency')
            if country and currency:
                country_currency_map[country.lower()] = currency

    all_currencies = sorted(forex_graph)
    # Return {} instead of None on a failed fetch so callers can use .get('data', []).
    return (stocks_data or {}), forex_graph, country_currency_map, all_currencies


ALL_STOCKS_CACHE, FOREX_GRAPH, COUNTRY_CURRENCY_MAP, AVAILABLE_CURRENCIES = load_market_data()
# --- 3. TOOL EXECUTION LOGIC ---
def find_and_process_stock(query: str):
    """Resolve *query* (ticker or company name) and load a unique match into the dashboard.

    The cached catalogue is searched first, using a case-insensitive substring match
    on symbol or name. When any listing's symbol equals the query exactly, only those
    listings are kept. If nothing matches, the TwelveData symbol search is used. On a
    single match the stock is added to the watchlist, its intraday series is fetched
    (``get_smart_time_series`` caches it) and the chart tab becomes active.

    Returns:
        dict: one of the following.

        - ``{"status": "SINGLE_STOCK_PROCESSED", "symbol": ..., "name": ...}``
        - ``{"status": "MULTIPLE_STOCKS_FOUND", "data": <first 5 matches>}``
        - ``{"status": "NO_STOCKS_FOUND"}``
    """
    print(f"Hybrid searching for stock: '{query}'...")
    query = query.strip()
    query_lower = query.lower()
    if not query_lower:
        return {"status": "NO_STOCKS_FOUND"}

    catalogue = (ALL_STOCKS_CACHE or {}).get('data') or []
    matches = [
        s for s in catalogue
        if query_lower in (s.get('symbol') or '').lower()
        or query_lower in (s.get('name') or '').lower()
    ]
    # An exact ticker (e.g. "F") would otherwise be drowned out by every symbol or
    # name that merely contains it.
    exact = [s for s in matches if (s.get('symbol') or '').lower() == query_lower]
    if exact:
        matches = exact
    if not matches:
        results = st.session_state.td_api.get_stocks(symbol=query) or {}
        matches = results.get('data') or []

    if len(matches) == 1:
        stock_info = matches[0]
        symbol = stock_info['symbol']
        st.session_state.stock_watchlist[symbol] = stock_info
        # get_smart_time_series stores the intraday DataFrame in timeseries_cache itself.
        get_smart_time_series(symbol=symbol, time_period='intraday')
        st.session_state.active_tab = 'Time Charts'
        st.session_state.active_timeseries_period = 'intraday'
        return {"status": "SINGLE_STOCK_PROCESSED", "symbol": symbol, "name": stock_info.get('name', 'N/A')}
    if len(matches) > 1:
        return {"status": "MULTIPLE_STOCKS_FOUND", "data": matches[:5]}
    return {"status": "NO_STOCKS_FOUND"}
# TwelveData request parameters for every chart period the UI and tools support.
_TIME_PERIOD_PARAMS = {
    'intraday': {'interval': '15min', 'outputsize': 120},
    '1_week': {'interval': '1h', 'outputsize': 40},
    '1_month': {'interval': '1day', 'outputsize': 22},
    '6_months': {'interval': '1day', 'outputsize': 120},
    '1_year': {'interval': '1week', 'outputsize': 52},
}


def get_smart_time_series(symbol: str, time_period: str):
    """Fetch price history for *symbol* over *time_period* and cache it for the charts.

    Args:
        symbol: Official ticker, e.g. ``"AAPL"``.
        time_period: One of ``intraday``, ``1_week``, ``1_month``, ``6_months``, ``1_year``.

    Returns:
        The raw ``get_time_series`` payload (containing ``values`` on success, ``{}``
        if the API returned nothing), or ``{"error": ...}`` for an unknown period.

    Side effects:
        - An unknown *symbol* is added to the watchlist using the first search hit,
          and the chart tab is selected.
        - On success, a datetime-indexed DataFrame is stored in
          ``timeseries_cache[symbol][time_period]``.
    """
    # Validate first so a bad period has no watchlist side effect.
    params = _TIME_PERIOD_PARAMS.get(time_period)
    if not params:
        return {"error": f"Time period '{time_period}' is not valid."}

    # If the symbol is not in the watchlist yet, look it up and add it first.
    if symbol not in st.session_state.stock_watchlist:
        results = st.session_state.td_api.get_stocks(symbol=symbol) or {}
        found_data = results.get('data') or []
        if found_data:
            st.session_state.stock_watchlist[symbol] = found_data[0]
            st.session_state.active_tab = 'Time Charts'

    result = st.session_state.td_api.get_time_series(symbol=symbol, **params) or {}
    # Cache only a non-empty series (an empty list has no datetime/close columns).
    if result.get('values'):
        df = pd.DataFrame(result['values'])
        df['datetime'] = pd.to_datetime(df['datetime'])
        df['close'] = pd.to_numeric(df['close'])
        st.session_state.timeseries_cache.setdefault(symbol, {})[time_period] = (
            df.sort_values('datetime').set_index('datetime')
        )
    return result
def find_conversion_path_bfs(start, end, graph=None):
    """Return the shortest chain of currency codes linking *start* to *end*.

    Uses breadth-first search, so the first path found has the fewest conversion hops.

    Args:
        start: Source currency code, e.g. ``"USD"``.
        end: Target currency code, e.g. ``"VND"``.
        graph: Adjacency mapping ``{code: [neighbour codes]}``. Defaults to the
            module-level ``FOREX_GRAPH``.

    Returns:
        list | None: ``[start, ..., end]`` (just ``[start]`` when they are equal), or
        ``None`` when either code is unknown or no path exists.
    """
    if graph is None:
        graph = FOREX_GRAPH
    if start not in graph or end not in graph:
        return None
    queue = deque([(start, [start])])
    visited = {start}
    while queue:
        current, path = queue.popleft()
        if current == end:
            return path
        for neighbour in graph.get(current, []):
            if neighbour not in visited:
                visited.add(neighbour)
                queue.append((neighbour, path + [neighbour]))
    return None
def convert_currency_with_bridge(amount: float, symbol: str):
    """Convert *amount* along the shortest forex path for a ``"FROM/TO"`` *symbol*.

    Each hop is priced with the TwelveData conversion endpoint. When a hop has no
    direct quote, the inverse pair is queried with amount 1 and its rate is inverted.

    Returns:
        dict: on success, ``{"status": "Success", "original_amount", "final_amount",
        "path_taken", "conversion_steps"}``; otherwise ``{"error": <message>}``.
    """
    # Tolerate surrounding whitespace ("usd / vnd") but require exactly two codes.
    codes = [part.strip() for part in symbol.upper().split('/')]
    if len(codes) != 2 or not all(codes):
        return {"error": "Invalid currency pair format."}
    start_currency, end_currency = codes

    path = find_conversion_path_bfs(start_currency, end_currency)
    if not path:
        return {"error": f"No conversion path found from {start_currency} to {end_currency}."}

    td_api = st.session_state.td_api
    current_amount = amount
    steps = []
    for step_no, (step_start, step_end) in enumerate(zip(path, path[1:]), start=1):
        direct = td_api.currency_conversion(amount=current_amount, symbol=f"{step_start}/{step_end}") or {}
        rate = direct.get('rate')
        if rate is not None:
            # Prefer the API's converted amount; derive it from the rate if absent.
            current_amount = direct['amount'] if 'amount' in direct else current_amount * rate
            steps.append({"step": f"{step_no}. {step_start}→{step_end}", "rate": rate, "intermediate_amount": current_amount})
            continue
        # No direct quote for this hop: price the inverse pair and invert its rate.
        inverse = td_api.currency_conversion(amount=1, symbol=f"{step_end}/{step_start}") or {}
        inverse_rate = inverse.get('rate')
        if not inverse_rate:  # missing or zero: cannot invert
            return {"error": f"Error in conversion step from {step_start} to {step_end}."}
        rate = 1 / inverse_rate
        current_amount *= rate
        steps.append({"step": f"{step_no}. {step_start}→{step_end} (Inverse)", "rate": rate, "intermediate_amount": current_amount})

    return {"status": "Success", "original_amount": amount, "final_amount": current_amount, "path_taken": path, "conversion_steps": steps}
def perform_currency_conversion(amount: float, symbol: str):
    """Tool entry point: convert currency and mirror the request in the converter tab.

    Delegates to ``convert_currency_with_bridge``. The result, the amount and the
    upper-cased currency codes are stored in ``currency_converter_state``, so the
    'Currency Converter' tab shows what was just converted.

    Returns:
        The dict returned by ``convert_currency_with_bridge``.
    """
    result = convert_currency_with_bridge(amount, symbol)
    state = st.session_state.currency_converter_state
    # The converter tab's number_input uses a float min_value, so store a float amount.
    stored_amount = float(amount) if isinstance(amount, (int, float)) else amount
    state.update({'result': result, 'amount': stored_amount})
    # Upper-case the codes so they match the entries offered by the currency selectboxes.
    codes = [part.strip() for part in symbol.upper().split('/')]
    if len(codes) == 2 and all(codes):
        state.update({'from': codes[0], 'to': codes[1]})
    st.session_state.active_tab = 'Currency Converter'
    return result
# --- 4. GEMINI CONFIGURATION ---
# System prompt passed to the Gemini model in get_model_and_tools(). It tells the model to:
#   - resolve company names or tickers with find_and_process_stock before other tools;
#   - infer the 3-letter currency code itself before calling perform_currency_conversion;
#   - reply briefly without repeating data, because the dashboard tabs already display it.
# The text is sent verbatim to the model, so any edit changes model behaviour.
SYSTEM_INSTRUCTION = """You are the AI brain controlling an Interactive Financial Dashboard. Your task is to understand user requests, call appropriate tools, and report results concisely.
GOLDEN RULES:
1. **UNDERSTAND FIRST, CALL LATER:**
* **Company Name:** When a user enters a company name (e.g., "Vingroup Corporation", "Apple"), your FIRST task is to use the `find_and_process_stock` tool to identify the official stock symbol.
* **Stock Symbol:** When a user directly provides a stock symbol (e.g., "AAPL", "VNM"), use `find_and_process_stock` first to confirm and add it to the watchlist.
* **Time Period Request:** When user asks for specific time period (e.g., "last year", "last month"), first make sure the stock symbol is processed with `find_and_process_stock`, then use `get_smart_time_series` with appropriate time_period.
* **Country Name:** When a user enters a country name for currency (e.g., "Vietnamese currency"), you must infer the 3-letter currency code ("VND") BEFORE calling the `perform_currency_conversion` tool.
2. **ACT AND NOTIFY:** Your role is to execute commands and report briefly.
* **Found 1 symbol:** "I've found [Company Name] ([Symbol]) and automatically added it to your watchlist and chart."
* **Found multiple symbols:** "I found several results for '[query]'. Please specify which exact symbol you want to track?"
* **Currency conversion:** "Done. Please see detailed results in the 'Currency Converter' tab."
3. **NO DATA LISTING:** The dashboard already displays everything. ABSOLUTELY do not repeat lists, numbers, or raw data in your response.
"""
@st.cache_resource
def get_model_and_tools():
    """Build the Gemini model with the dashboard's three tool declarations attached.

    Wrapped in ``st.cache_resource``, so reruns reuse the same model object.
    Despite the name, only the ``GenerativeModel`` is returned.
    """
    object_type = glm.Type.OBJECT
    string_type = glm.Type.STRING

    find_stock_params = glm.Schema(
        type=object_type,
        properties={
            'query': glm.Schema(type=string_type, description="Symbol or company name, e.g., 'Vingroup', 'Apple'."),
        },
        required=['query'],
    )
    time_series_params = glm.Schema(
        type=object_type,
        properties={
            'symbol': glm.Schema(type=string_type),
            'time_period': glm.Schema(type=string_type, enum=["intraday", "1_week", "1_month", "6_months", "1_year"]),
        },
        required=['symbol', 'time_period'],
    )
    currency_params = glm.Schema(
        type=object_type,
        properties={
            'amount': glm.Schema(type=glm.Type.NUMBER),
            'symbol': glm.Schema(type=string_type),
        },
        required=['amount', 'symbol'],
    )

    declarations = [
        glm.FunctionDeclaration(
            name="find_and_process_stock",
            description="Search for stock by symbol or name and automatically process. Use this tool FIRST to identify the official stock symbol.",
            parameters=find_stock_params,
        ),
        glm.FunctionDeclaration(
            name="get_smart_time_series",
            description="Get price history data after knowing the official stock symbol.",
            parameters=time_series_params,
        ),
        glm.FunctionDeclaration(
            name="perform_currency_conversion",
            description="Convert currency after knowing the 3-letter code of source/target currency pair, e.g., USD/VND, JPY/EUR",
            parameters=currency_params,
        ),
    ]
    return genai.GenerativeModel(
        model_name=MODEL_NAME,
        tools=[glm.Tool(function_declarations=declarations)],
        system_instruction=SYSTEM_INSTRUCTION,
    )
model = get_model_and_tools()

# One chat session per browser session; created the first time this line runs.
if st.session_state.chat_session is None:
    st.session_state.chat_session = model.start_chat(history=[])

# Dispatch table: tool name declared to Gemini -> local Python implementation.
AVAILABLE_FUNCTIONS = dict(
    find_and_process_stock=find_and_process_stock,
    get_smart_time_series=get_smart_time_series,
    perform_currency_conversion=perform_currency_conversion,
)
# --- 5. TAB DISPLAY LOGIC ---
def get_y_axis_domain(series: pd.Series, padding_percent: float = 0.1):
    """Return a padded ``[low, high]`` y-axis domain for *series*, or ``None``.

    Padding is added on each side of the data:

    - For data with a spread, it is ``padding_percent`` of the min-max range.
    - For a flat series, it is ``padding_percent / 2`` of the value's magnitude.
    - For all-zero data it falls back to 1.0, so the domain never collapses to a point.

    Bounds are plain Python floats, so the chart spec never holds NumPy scalars.

    Returns:
        list[float] | None: ``None`` for an empty or all-NaN series.
    """
    if series.empty:
        return None
    data_min, data_max = series.min(), series.max()
    if pd.isna(data_min) or pd.isna(data_max):
        return None
    data_min, data_max = float(data_min), float(data_max)
    data_range = data_max - data_min
    if data_range == 0:
        padding = abs(data_max) * (padding_percent / 2) or 1.0
    else:
        padding = data_range * padding_percent
    return [data_min - padding, data_max + padding]
def render_watchlist_tab():
    """Render the 'Stock Watchlist' tab: one row per tracked symbol with a delete button.

    Deleting a symbol also drops its cached price series, then reruns the script.
    """
    st.subheader("Watchlist")
    watchlist = st.session_state.stock_watchlist
    if not watchlist:
        st.info("No stocks yet. Try searching for a symbol like 'Apple' or 'VNM'.")
        return
    # Iterate over a snapshot, because a delete click mutates the watchlist.
    for symbol, info in list(watchlist.items()):
        name_col, venue_col, action_col = st.columns([4, 4, 1])
        with name_col:
            st.markdown(f"**{symbol}**")
            st.caption(info.get('name', 'N/A'))
        with venue_col:
            st.markdown(f"**{info.get('exchange', 'N/A')}**")
            st.caption(f"{info.get('country', 'N/A')} - {info.get('currency', 'N/A')}")
        with action_col:
            if st.button("🗑️", key=f"delete_{symbol}", help=f"Delete {symbol}"):
                watchlist.pop(symbol, None)
                st.session_state.timeseries_cache.pop(symbol, None)
                st.rerun()
        st.divider()
def _growth_frame(series_by_symbol):
    """Stack each symbol's close prices as percent change since its first data point.

    Args:
        series_by_symbol: ``{symbol: DataFrame}`` with a ``close`` column and a
            ``datetime`` index, as stored in ``timeseries_cache``.

    Returns:
        DataFrame with columns ``datetime``, ``value`` (percent) and ``symbol``, or
        ``None`` when no series has a usable (non-empty, non-zero) starting price.
    """
    frames = []
    for symbol, df in series_by_symbol.items():
        closes = df['close'].dropna()
        if closes.empty or closes.iloc[0] == 0:
            continue  # no baseline to measure growth against
        growth = (closes / closes.iloc[0] - 1) * 100
        frame = growth.reset_index()
        frame.columns = ['datetime', 'value']
        frame['symbol'] = symbol
        frames.append(frame)
    return pd.concat(frames, ignore_index=True) if frames else None


def _y_scale(domain):
    """Return an Altair scale without a forced zero; pin *domain* only if it is a real interval."""
    if domain is None or domain[0] == domain[1]:
        return alt.Scale(zero=False)
    return alt.Scale(domain=domain, zero=False)


def render_timeseries_tab():
    """Render the 'Time Charts' tab for every watchlist symbol.

    A period picker defaults to ``active_timeseries_period``; changing it refetches
    every symbol and reruns the script. Two kinds of chart are then drawn from
    ``timeseries_cache``:

    - a growth-comparison chart (percent change since each series' first point);
    - one price chart per symbol.
    """
    st.subheader("Chart Analysis")
    watchlist = st.session_state.stock_watchlist
    if not watchlist:
        st.info("Please add at least one stock to the watchlist to view charts.")
        return

    time_periods = {'Intraday': 'intraday', '1 Week': '1_week', '1 Month': '1_month', '6 Months': '6_months', '1 Year': '1_year'}
    period_labels = list(time_periods.keys())
    period_values = list(time_periods.values())
    active_period = st.session_state.active_timeseries_period
    default_index = period_values.index(active_period) if active_period in period_values else 0
    selected_label = st.radio("Select time period:", options=period_labels, horizontal=True, index=default_index)
    selected_period = time_periods[selected_label]

    if active_period != selected_period:
        st.session_state.active_timeseries_period = selected_period
        with st.spinner("Updating charts..."):
            # get_smart_time_series stores each successful fetch in timeseries_cache itself.
            for symbol in list(watchlist):
                get_smart_time_series(symbol, selected_period)
        st.rerun()

    cache = st.session_state.timeseries_cache
    all_series_data = {
        symbol: cache[symbol][selected_period]
        for symbol in watchlist
        if selected_period in cache.get(symbol, {})
    }
    if not all_series_data:
        st.warning("Not enough data for the selected time period.")
        return

    st.markdown("##### Growth Performance Comparison (%)")
    growth_df = _growth_frame(all_series_data)
    if growth_df is not None:
        growth_chart = alt.Chart(growth_df).mark_line().encode(
            x=alt.X('datetime:T', title='Time'),
            y=alt.Y('value:Q', scale=_y_scale(get_y_axis_domain(growth_df['value'])), title='Growth (%)'),
            color=alt.Color('symbol:N', title='Symbol'),
            tooltip=[
                alt.Tooltip('symbol:N', title='Symbol'),
                alt.Tooltip('datetime:T', title='Time', format='%Y-%m-%d %H:%M'),
                alt.Tooltip('value:Q', title='Growth', format='.2f'),
            ],
        ).interactive()
        st.altair_chart(growth_chart, use_container_width=True)
    else:
        st.warning("No data to draw growth chart.")
    st.divider()

    st.markdown("##### Actual Price Charts")
    for symbol, df in all_series_data.items():
        stock_info = watchlist.get(symbol, {})
        st.markdown(f"**{symbol}** ({stock_info.get('currency', 'N/A')})")
        if df.empty:
            continue
        price_chart = alt.Chart(df.reset_index()).mark_line().encode(
            x=alt.X('datetime:T', title='Time'),
            y=alt.Y('close:Q', scale=_y_scale(get_y_axis_domain(df['close'])), title='Price'),
            tooltip=[
                alt.Tooltip('datetime:T', title='Time', format='%Y-%m-%d %H:%M'),
                alt.Tooltip('close:Q', title='Price', format=',.2f'),
            ],
        ).interactive()
        st.altair_chart(price_chart, use_container_width=True)
def render_currency_tab():
    """Render the 'Currency Converter' tab from ``currency_converter_state``.

    The widgets are pre-filled from that state, which the AI tool also updates.
    Pressing Convert runs ``perform_currency_conversion`` and reruns, so the stored
    result appears below the form.
    """
    st.subheader("Currency Converter Tool")
    state = st.session_state.currency_converter_state
    currencies = AVAILABLE_CURRENCIES

    def _index_of(code, fallback):
        # Clamp the fallback so a short or empty currency list never yields an
        # out-of-range selectbox index.
        if code in currencies:
            return currencies.index(code)
        return min(fallback, max(len(currencies) - 1, 0))

    col1, col2 = st.columns(2)
    # Keep value the same numeric type as min_value (float) and never below it.
    start_amount = max(float(state['amount']), 0.0)
    amount = col1.number_input("Amount", value=start_amount, min_value=0.0, format="%.2f", key="conv_amount")
    from_curr = col1.selectbox("From", options=currencies, index=_index_of(state['from'], 0), key="conv_from")
    to_curr = col2.selectbox("To", options=currencies, index=_index_of(state['to'], 1), key="conv_to")
    if st.button("Convert", use_container_width=True, key="conv_btn"):
        with st.spinner("Converting..."):
            perform_currency_conversion(amount, f"{from_curr}/{to_curr}")
        st.rerun()

    res = state['result']
    if res:
        if res.get('status') == 'Success':
            st.success(f"**Result:** `{res['original_amount']:,.2f} {res['path_taken'][0]}` = `{res['final_amount']:,.2f} {res['path_taken'][-1]}`")
        else:
            st.error(f"Error: {res.get('error', 'Unknown')}")
# --- 6. MAIN APP LAYOUT & CONTROL FLOW ---
st.title("📈 AI Financial Dashboard")

# Split the page into two equal columns: AI chat (left) and dashboard tabs (right).
chat_col, dashboard_col = st.columns([1, 1])

# Left column: fixed-height transcript of the conversation so far.
with chat_col:
    chat_container = st.container(height=600)
    with chat_container:
        for turn in st.session_state.chat_history:
            with st.chat_message(turn["role"]):
                st.markdown(turn["parts"])

# Right column: the chart and data tabs.
with dashboard_col:
    right_column_container = st.container(height=600)
    with right_column_container:
        tab_names = ['Stock Watchlist', 'Time Charts', 'Currency Converter']
        requested_tab = st.session_state.active_tab
        default_index = tab_names.index(requested_tab) if requested_tab in tab_names else 0
        # Normalise an unknown tab name back to a valid one.
        # NOTE(review): default_index is not passed to st.tabs, so active_tab does not
        # currently choose which tab is shown.
        st.session_state.active_tab = tab_names[default_index]
        watchlist_tab, charts_tab, currency_tab = st.tabs(tab_names)
        with watchlist_tab:
            render_watchlist_tab()
        with charts_tab:
            render_timeseries_tab()
        with currency_tab:
            render_currency_tab()
# The chat input sits at the bottom of the page.
if user_prompt := st.chat_input("Ask AI to control the dashboard..."):
    # Queue the prompt and rerun; the rerun then processes the pending user turn.
    st.session_state.chat_history.append({"role": "user", "parts": user_prompt})
    st.rerun()
# Handle the pending user prompt (if any) and display the AI's response.
def _function_calls(response):
    """Return the function calls requested in *response*'s first candidate (may be empty)."""
    if not response.candidates:
        return []
    return [part.function_call for part in response.candidates[0].content.parts if part.function_call]


def _run_tool_call(call):
    """Execute one model-requested tool and wrap its outcome as a function_response Part.

    Unknown tools and tool exceptions are reported back to the model as
    ``{'error': ...}`` instead of crashing the page.
    """
    func_name = call.name
    func_args = {k: v for k, v in call.args.items()} if call.args else {}
    func = AVAILABLE_FUNCTIONS.get(func_name)
    if func is None:
        payload = {'error': f"Function '{func_name}' not found."}
    else:
        try:
            payload = {'result': func(**func_args)}
        except Exception as exc:  # tool boundary: report the failure to the model
            print(f"Tool '{func_name}' failed: {exc!r}")
            payload = {'error': f"Function '{func_name}' failed: {exc}"}
    return glm.Part(function_response=glm.FunctionResponse(name=func_name, response=payload))


# Time phrases in the user's prompt -> chart period. Checked in order, so
# "last 6 months" wins over "last month".
_PERIOD_KEYWORDS = (
    ("last year", "1_year"),
    ("last 6 months", "6_months"),
    ("last month", "1_month"),
    ("last week", "1_week"),
)


def _apply_period_keywords(prompt):
    """Switch the chart period when *prompt* names one, fetching series not yet cached."""
    old_period = st.session_state.active_timeseries_period
    text = (prompt or "").lower()
    for phrase, period in _PERIOD_KEYWORDS:
        if phrase in text:
            st.session_state.active_timeseries_period = period
            break
    new_period = st.session_state.active_timeseries_period
    # If the period changed and the watchlist has stocks, refresh the missing data.
    if new_period != old_period and st.session_state.stock_watchlist:
        cache = st.session_state.timeseries_cache
        for symbol in list(st.session_state.stock_watchlist):
            if new_period not in cache.get(symbol, {}):
                get_smart_time_series(symbol, new_period)
        st.session_state.active_tab = 'Time Charts'


if st.session_state.chat_history and st.session_state.chat_history[-1]["role"] == "user":
    last_user_prompt = st.session_state.chat_history[-1]["parts"]
    with chat_container:
        with st.chat_message("model"):
            with st.spinner("🤖 AI executing command..."):
                awaiting_tool_reply = False
                try:
                    chat = st.session_state.chat_session
                    response = chat.send_message(last_user_prompt)
                    tool_calls = _function_calls(response)
                    while tool_calls:
                        awaiting_tool_reply = True
                        tool_responses = [_run_tool_call(call) for call in tool_calls]
                        response = chat.send_message(glm.Content(parts=tool_responses))
                        awaiting_tool_reply = False
                        tool_calls = _function_calls(response)
                    _apply_period_keywords(last_user_prompt)
                    try:
                        reply_text = response.text
                    except ValueError:
                        # .text raises when the reply carries no usable text part.
                        reply_text = "Done. Please check the dashboard for the results."
                except Exception as exc:  # top-level boundary: always answer the prompt
                    print(f"Chat turn failed: {exc!r}")
                    if awaiting_tool_reply:
                        # The failed send carried tool results, so the history may now end
                        # with an unanswered function call. Start a fresh session so later
                        # turns are not rejected.
                        # NOTE(review): this drops the conversation context; confirm that
                        # is acceptable.
                        st.session_state.chat_session = model.start_chat(history=[])
                    reply_text = f"⚠️ Sorry, something went wrong while processing your request: {exc}"
    # Recording a model turn clears the pending prompt, so it is not resent on the next rerun.
    st.session_state.chat_history.append({"role": "model", "parts": reply_text})
    st.rerun()