Spaces:
Sleeping
Sleeping
# app.py (Final version with Advanced Charts) | |
import streamlit as st | |
import pandas as pd | |
import altair as alt # <-- Add Altair library | |
import google.generativeai as genai | |
import google.ai.generativelanguage as glm | |
from dotenv import load_dotenv | |
import os | |
from twelvedata_api import TwelveDataAPI | |
from collections import deque | |
from datetime import datetime | |
# --- 1. INITIAL CONFIGURATION & STATE INITIALIZATION ---
# Load variables from a local .env file (if present) into os.environ so the
# os.getenv calls below can see them.
load_dotenv()

# Gemini model to use; the GEMINI_MODEL environment variable overrides the default.
MODEL_NAME = os.getenv("GEMINI_MODEL", "gemini-2.5-flash")

# Page settings kept consistent with the app's other pages; applied before
# any Streamlit element is drawn.
st.set_page_config(
    layout="wide",
    page_icon="📊",
    page_title="AI Financial Dashboard",
)

# Authenticate the Gemini SDK with the key taken from the environment.
genai.configure(api_key=os.getenv("GEMINI_API_KEY"))
def initialize_state():
    """Create the per-session defaults in st.session_state on a session's first run.

    Streamlit re-executes the script on every interaction; the ``initialized``
    flag turns this into a no-op after the first *successful* run. Stored keys:
    the Twelve Data client (``td_api``), the stock watchlist and its price
    cache, the active chart period, the currency-converter form state, the
    chat transcript, the active tab and the Gemini chat session placeholder.
    """
    if "initialized" in st.session_state:
        return
    st.session_state.td_api = TwelveDataAPI(os.getenv("TWELVEDATA_API_KEY"))
    st.session_state.stock_watchlist = {}
    st.session_state.timeseries_cache = {}
    st.session_state.active_timeseries_period = 'intraday'
    st.session_state.currency_converter_state = {'from': 'USD', 'to': 'VND', 'amount': 100.0, 'result': None}
    st.session_state.chat_history = []
    st.session_state.active_tab = 'Stock Watchlist'
    st.session_state.chat_session = None
    # Set the flag last: if anything above raises (e.g. the API client
    # constructor), the next rerun retries the initialization instead of
    # skipping it and leaving a half-populated session.
    st.session_state.initialized = True


initialize_state()
# --- 2. LOAD BACKGROUND DATA ---
def load_market_data():
    """Load the stock and forex listings and derive the lookup tables the app uses.

    Calls ``get_all_stocks`` and ``get_forex_pairs`` on the session's
    TwelveDataAPI client, then builds:

    * an undirected currency graph ``{code: [neighbour codes]}`` from the
      "BASE/QUOTE" forex pair symbols (used for bridged conversions),
    * a ``{country (lower-case): currency}`` map from the stock listing, and
    * the sorted list of currency codes present in the graph.

    Streamlit re-executes this module on every interaction. Once both listings
    have loaded, the result is therefore kept in
    ``st.session_state.market_data_cache`` instead of being re-downloaded on
    each rerun. Incomplete results are not cached, so a transient API failure
    is retried on the next rerun.

    Returns:
        ``(stocks_data, forex_graph, country_currency_map, all_currencies)``.
        ``stocks_data`` is the raw listing response, or ``{}`` when the API did
        not return a dict.
    """
    cached = st.session_state.get("market_data_cache")
    if cached is not None:
        return cached

    td_api = st.session_state.td_api
    stocks_data = td_api.get_all_stocks()
    forex_data = td_api.get_forex_pairs()
    stocks_ok = isinstance(stocks_data, dict) and 'data' in stocks_data
    forex_ok = isinstance(forex_data, dict) and 'data' in forex_data

    forex_graph = {}
    if forex_ok:
        for pair in forex_data['data']:
            currencies = str(pair.get('symbol', '')).split('/')
            # Skip malformed symbols instead of aborting the whole app start-up.
            if len(currencies) != 2 or not all(currencies):
                continue
            base, quote = currencies
            for src, dst in ((base, quote), (quote, base)):
                neighbours = forex_graph.setdefault(src, [])
                # A pair listed in both directions would otherwise add duplicate edges.
                if dst not in neighbours:
                    neighbours.append(dst)

    country_currency_map = {}
    if stocks_ok:
        for stock in stocks_data['data']:
            country, currency = stock.get('country'), stock.get('currency')
            if country and currency:
                country_currency_map[country.lower()] = currency

    all_currencies = sorted(forex_graph)
    market_data = (
        stocks_data if isinstance(stocks_data, dict) else {},
        forex_graph,
        country_currency_map,
        all_currencies,
    )
    if stocks_ok and forex_ok:
        st.session_state.market_data_cache = market_data
    return market_data


ALL_STOCKS_CACHE, FOREX_GRAPH, COUNTRY_CURRENCY_MAP, AVAILABLE_CURRENCIES = load_market_data()
# --- 3. TOOL EXECUTION LOGIC ---
def find_and_process_stock(query: str):
    """Resolve *query* to a stock and, when unambiguous, add it to the dashboard.

    Search order:
      1. The stock listing loaded at start-up (ALL_STOCKS_CACHE). Stocks whose
         symbol or name equals *query* (case-insensitive) win. Otherwise, any
         stock whose symbol or name contains *query* matches.
      2. If nothing matched locally, a live ``get_stocks(symbol=...)`` lookup.

    When exactly one stock matches, it is added to the watchlist and its
    intraday series is fetched (get_smart_time_series caches it). The
    dashboard then switches to the 'Time Charts' tab with the intraday period.

    Args:
        query: Ticker symbol or company name, e.g. "AAPL" or "Apple".

    Returns:
        A dict whose ``status`` is one of:
        - "SINGLE_STOCK_PROCESSED", with ``symbol`` and ``name``;
        - "MULTIPLE_STOCKS_FOUND", with up to five candidates under ``data``;
        - "NO_STOCKS_FOUND".
    """
    print(f"Hybrid searching for stock: '{query}'...")
    query_clean = query.strip()
    query_lower = query_clean.lower()
    if not query_lower:
        # An empty query would otherwise "match" every listed stock.
        return {"status": "NO_STOCKS_FOUND"}

    listing = (ALL_STOCKS_CACHE.get('data') or []) if isinstance(ALL_STOCKS_CACHE, dict) else []
    exact_matches, partial_matches = [], []
    for stock in listing:
        symbol_lower = (stock.get('symbol') or '').lower()
        name_lower = (stock.get('name') or '').lower()
        if query_lower in (symbol_lower, name_lower):
            exact_matches.append(stock)
        elif query_lower in symbol_lower or query_lower in name_lower:
            partial_matches.append(stock)
    # An exact symbol/name hit is what the user asked for; substring hits
    # (e.g. the query inside a longer ticker) should not make it ambiguous.
    found_data = exact_matches or partial_matches

    if not found_data:
        # Not in the preloaded listing: ask the API directly.
        results = st.session_state.td_api.get_stocks(symbol=query_clean) or {}
        found_data = results.get('data') or []

    if len(found_data) > 1:
        return {"status": "MULTIPLE_STOCKS_FOUND", "data": found_data[:5]}
    if not found_data:
        return {"status": "NO_STOCKS_FOUND"}

    stock_info = found_data[0]
    symbol = stock_info['symbol']
    st.session_state.stock_watchlist[symbol] = stock_info
    # get_smart_time_series parses the series and stores it in timeseries_cache itself.
    get_smart_time_series(symbol=symbol, time_period='intraday')
    st.session_state.active_tab = 'Time Charts'
    st.session_state.active_timeseries_period = 'intraday'
    return {"status": "SINGLE_STOCK_PROCESSED", "symbol": symbol, "name": stock_info.get('name', 'N/A')}
def get_smart_time_series(symbol: str, time_period: str):
    """Fetch price history for *symbol* over a named period and cache it for the charts.

    The period name is first mapped to a Twelve Data interval/output size.
    Unknown periods are rejected before anything else happens. If *symbol* is
    not yet on the watchlist, it is looked up via ``get_stocks``. When found,
    it is added to the watchlist and the 'Time Charts' tab is marked active.
    The series is then requested. On success the closing prices are stored in
    ``st.session_state.timeseries_cache[symbol][time_period]`` as a
    datetime-indexed, chronologically sorted DataFrame.

    Args:
        symbol: Stock ticker, e.g. "AAPL".
        time_period: One of "intraday", "1_week", "1_month", "6_months", "1_year".

    Returns:
        The raw API response dict (containing ``values`` on success), or a
        dict with an ``error`` key when the period is unknown or the API
        returned no dict at all.
    """
    period_params = {
        'intraday': {'interval': '15min', 'outputsize': 120},
        '1_week': {'interval': '1h', 'outputsize': 40},
        '1_month': {'interval': '1day', 'outputsize': 22},
        '6_months': {'interval': '1day', 'outputsize': 120},
        '1_year': {'interval': '1week', 'outputsize': 52},
    }
    params = period_params.get(time_period)
    if not params:
        return {"error": f"Time period '{time_period}' is not valid."}

    watchlist = st.session_state.stock_watchlist
    if symbol not in watchlist:
        # Look the symbol up and add it to the watchlist before charting it.
        lookup = st.session_state.td_api.get_stocks(symbol=symbol) or {}
        matches = lookup.get('data') or []
        if matches:
            watchlist[symbol] = matches[0]
            st.session_state.active_tab = 'Time Charts'

    result = st.session_state.td_api.get_time_series(symbol=symbol, **params)
    if not isinstance(result, dict):
        return {"error": f"No time series data returned for '{symbol}'."}

    # On success, refresh the chart cache. An empty 'values' list would give
    # a DataFrame without a 'datetime' column, so it is skipped.
    values = result.get('values')
    if values:
        df = pd.DataFrame(values)
        df['datetime'] = pd.to_datetime(df['datetime'])
        df['close'] = pd.to_numeric(df['close'])
        st.session_state.timeseries_cache.setdefault(symbol, {})[time_period] = (
            df.sort_values('datetime').set_index('datetime')
        )
    return result
def find_conversion_path_bfs(start, end, graph=None):
    """Return the shortest chain of currencies linking *start* to *end*.

    Runs a breadth-first search over an undirected currency graph. By default
    this is the module-level FOREX_GRAPH built from the available forex pairs.
    The returned path therefore needs the fewest conversion hops. Neighbours
    are explored in list order, so ties resolve to the first-listed route.

    Args:
        start: Source currency code, e.g. "USD".
        end: Target currency code, e.g. "VND".
        graph: Optional adjacency mapping ``{currency: [neighbour, ...]}``.
            Defaults to FOREX_GRAPH.

    Returns:
        The list of currency codes from *start* to *end* inclusive
        (``[start]`` when they are equal). Returns None when either currency is
        not in the graph or no chain of pairs connects them.
    """
    if graph is None:
        graph = FOREX_GRAPH
    if start not in graph or end not in graph:
        return None
    # parent[c] is the currency from which c was first reached; the start has none.
    parent = {start: None}
    queue = deque([start])
    while queue:
        current = queue.popleft()
        if current == end:
            path = []
            while current is not None:
                path.append(current)
                current = parent[current]
            return path[::-1]
        for neighbour in graph.get(current, []):
            if neighbour not in parent:
                parent[neighbour] = current
                queue.append(neighbour)
    return None
def _as_float(value):
    """Return *value* converted to float, or None when it is missing or not numeric."""
    try:
        return float(value)
    except (TypeError, ValueError):
        return None


def convert_currency_with_bridge(amount: float, symbol: str):
    """Convert *amount* across the pair *symbol*, hopping through intermediate currencies if needed.

    The shortest currency chain comes from find_conversion_path_bfs. Each hop
    is then priced with ``td_api.currency_conversion``. When a hop has no
    direct quote, the reverse pair is queried for 1 unit and its rate is
    inverted.

    Args:
        amount: Amount in the source currency.
        symbol: Pair in "FROM/TO" form; case and surrounding spaces are ignored.

    Returns:
        On success, ``{"status": "Success", "original_amount", "final_amount",
        "path_taken", "conversion_steps"}``. Otherwise ``{"error": message}``.
    """
    try:
        start_currency, end_currency = (code.strip() for code in symbol.upper().split('/'))
    except ValueError:
        return {"error": "Invalid currency pair format."}
    path = find_conversion_path_bfs(start_currency, end_currency)
    if not path:
        return {"error": f"No conversion path found from {start_currency} to {end_currency}."}

    td_api = st.session_state.td_api
    current_amount = amount
    steps = []
    for step_no, (step_start, step_end) in enumerate(zip(path, path[1:]), start=1):
        # A failed request may return None or a dict without 'rate'; treat both as "no quote".
        direct = td_api.currency_conversion(amount=current_amount, symbol=f"{step_start}/{step_end}") or {}
        direct_rate = _as_float(direct.get('rate'))
        if direct_rate is not None:
            converted = _as_float(direct.get('amount'))
            # Prefer the API's converted amount; derive it from the rate if absent.
            current_amount = converted if converted is not None else current_amount * direct_rate
            steps.append({"step": f"{step_no}. {step_start} → {step_end}", "rate": direct_rate, "intermediate_amount": current_amount})
            continue

        inverse = td_api.currency_conversion(amount=1, symbol=f"{step_end}/{step_start}") or {}
        inverse_rate = _as_float(inverse.get('rate'))
        if not inverse_rate:  # missing, non-numeric or zero: this hop cannot be priced
            return {"error": f"Error in conversion step from {step_start} to {step_end}."}
        rate = 1 / inverse_rate
        current_amount *= rate
        steps.append({"step": f"{step_no}. {step_start} → {step_end} (Inverse)", "rate": rate, "intermediate_amount": current_amount})

    return {"status": "Success", "original_amount": amount, "final_amount": current_amount, "path_taken": path, "conversion_steps": steps}
def perform_currency_conversion(amount: float, symbol: str):
    """Convert *amount* along *symbol* and mirror the outcome into the Currency Converter tab.

    This is the tool entry point used by Gemini and by the converter tab's
    button. It runs convert_currency_with_bridge and stores in
    st.session_state.currency_converter_state:
    - the result and the amount;
    - the currency codes, when *symbol* is a well-formed "FROM/TO" pair.
    It then marks 'Currency Converter' as the active tab.

    Args:
        amount: Amount of the source currency.
        symbol: Currency pair in "FROM/TO" form, e.g. "USD/VND".

    Returns:
        The dict produced by convert_currency_with_bridge.
    """
    result = convert_currency_with_bridge(amount, symbol)
    converter_state = st.session_state.currency_converter_state
    # Store a float: the tab's st.number_input uses min_value=0.0, and
    # Streamlit rejects mixing int and float numeric arguments.
    converter_state.update({'result': result, 'amount': float(amount)})
    codes = [code.strip().upper() for code in symbol.split('/')]
    # Only sync the selectors for a well-formed pair. Upper-case the codes so
    # they match AVAILABLE_CURRENCIES (the conversion itself upper-cases them).
    if len(codes) == 2 and all(codes):
        converter_state.update({'from': codes[0], 'to': codes[1]})
    st.session_state.active_tab = 'Currency Converter'
    return result
# --- 4. GEMINI CONFIGURATION ---
# System prompt given to the Gemini model. It says which tool to call for each
# kind of request and asks for short replies, because the dashboard tabs
# already show the data. Written one rule per line for easier editing.
SYSTEM_INSTRUCTION = "\n".join((
    "You are the AI brain controlling an Interactive Financial Dashboard. Your task is to understand user requests, call appropriate tools, and report results concisely.",
    "GOLDEN RULES:",
    "1. **UNDERSTAND FIRST, CALL LATER:**",
    '* **Company Name:** When a user enters a company name (e.g., "Vingroup Corporation", "Apple"), your FIRST task is to use the `find_and_process_stock` tool to identify the official stock symbol.',
    '* **Stock Symbol:** When a user directly provides a stock symbol (e.g., "AAPL", "VNM"), use `find_and_process_stock` first to confirm and add it to the watchlist.',
    '* **Time Period Request:** When user asks for specific time period (e.g., "last year", "last month"), first make sure the stock symbol is processed with `find_and_process_stock`, then use `get_smart_time_series` with appropriate time_period.',
    '* **Country Name:** When a user enters a country name for currency (e.g., "Vietnamese currency"), you must infer the 3-letter currency code ("VND") BEFORE calling the `perform_currency_conversion` tool.',
    "2. **ACT AND NOTIFY:** Your role is to execute commands and report briefly.",
    "* **Found 1 symbol:** \"I've found [Company Name] ([Symbol]) and automatically added it to your watchlist and chart.\"",
    "* **Found multiple symbols:** \"I found several results for '[query]'. Please specify which exact symbol you want to track?\"",
    "* **Currency conversion:** \"Done. Please see detailed results in the 'Currency Converter' tab.\"",
    "3. **NO DATA LISTING:** The dashboard already displays everything. ABSOLUTELY do not repeat lists, numbers, or raw data in your response.",
    "",  # the prompt ends with a newline
))
def get_model_and_tools():
    """Build the Gemini model with the dashboard's three callable tools attached.

    Declares find_and_process_stock, get_smart_time_series and
    perform_currency_conversion as function tools. It then creates a
    GenerativeModel configured with MODEL_NAME and SYSTEM_INSTRUCTION.

    Returns:
        The configured ``genai.GenerativeModel``.
    """
    string_type = glm.Type.STRING
    object_type = glm.Type.OBJECT

    stock_search = glm.FunctionDeclaration(
        name="find_and_process_stock",
        description="Search for stock by symbol or name and automatically process. Use this tool FIRST to identify the official stock symbol.",
        parameters=glm.Schema(
            type=object_type,
            properties={
                'query': glm.Schema(type=string_type, description="Symbol or company name, e.g., 'Vingroup', 'Apple'."),
            },
            required=['query'],
        ),
    )

    period_schema = glm.Schema(type=string_type, enum=["intraday", "1_week", "1_month", "6_months", "1_year"])
    price_history = glm.FunctionDeclaration(
        name="get_smart_time_series",
        description="Get price history data after knowing the official stock symbol.",
        parameters=glm.Schema(
            type=object_type,
            properties={'symbol': glm.Schema(type=string_type), 'time_period': period_schema},
            required=['symbol', 'time_period'],
        ),
    )

    fx_conversion = glm.FunctionDeclaration(
        name="perform_currency_conversion",
        description="Convert currency after knowing the 3-letter code of source/target currency pair, e.g., USD/VND, JPY/EUR",
        parameters=glm.Schema(
            type=object_type,
            properties={'amount': glm.Schema(type=glm.Type.NUMBER), 'symbol': glm.Schema(type=string_type)},
            required=['amount', 'symbol'],
        ),
    )

    toolbox = glm.Tool(function_declarations=[stock_search, price_history, fx_conversion])
    return genai.GenerativeModel(
        model_name=MODEL_NAME,
        tools=[toolbox],
        system_instruction=SYSTEM_INSTRUCTION,
    )
model = get_model_and_tools()
# One Gemini chat session per browser session: it is created on the first
# run and the stored one is reused afterwards, so earlier turns stay in its history.
if st.session_state.chat_session is None:
    st.session_state.chat_session = model.start_chat(history=[])

# Dispatch table for the tool-calling loop: each key is the function's own
# name, which is also the name used in its tool declaration.
AVAILABLE_FUNCTIONS = {
    tool.__name__: tool
    for tool in (find_and_process_stock, get_smart_time_series, perform_currency_conversion)
}
# --- 5. TAB DISPLAY LOGIC ---
def get_y_axis_domain(series: pd.Series, padding_percent: float = 0.1):
    """Return a padded ``[low, high]`` y-axis domain for *series*, or None if it has no data.

    Normally the domain extends past the data range by ``padding_percent`` of
    that range on each side. For a constant series the padding is half of
    ``padding_percent`` of the value's magnitude. If that is zero (e.g. an
    all-zero series), the padding is 1.0, so the axis never collapses to a
    single point.

    Args:
        series: Numeric values to plot; NaNs are skipped by ``min``/``max``.
        padding_percent: Fraction of the range added on each side (0.1 = 10%).

    Returns:
        ``[low, high]`` as Python floats, or None for an empty or all-NaN series.
    """
    if series.empty:
        return None
    data_min, data_max = series.min(), series.max()
    if pd.isna(data_min) or pd.isna(data_max):
        return None
    data_range = data_max - data_min
    if data_range == 0:
        padding = abs(data_max * (padding_percent / 2))
        if padding == 0:
            padding = 1.0
    else:
        padding = data_range * padding_percent
    # Plain floats keep the domain JSON-serialisable for the chart spec.
    return [float(data_min - padding), float(data_max + padding)]
def render_watchlist_tab():
    """Render the 'Stock Watchlist' tab.

    Each watched symbol is shown with its name, exchange, country and
    currency. Each row has a delete button that removes the symbol (and its
    cached price series) and reruns the script.
    """
    st.subheader("Watchlist")
    watchlist = st.session_state.stock_watchlist
    if not watchlist:
        st.info("No stocks yet. Try searching for a symbol like 'Apple' or 'VNM'.")
        return
    # Iterate over a snapshot because the delete button mutates the watchlist.
    for ticker, info in list(watchlist.items()):
        name_col, market_col, action_col = st.columns([4, 4, 1])
        with name_col:
            st.markdown(f"**{ticker}**")
            st.caption(info.get('name', 'N/A'))
        with market_col:
            st.markdown(f"**{info.get('exchange', 'N/A')}**")
            st.caption(f"{info.get('country', 'N/A')} - {info.get('currency', 'N/A')}")
        with action_col:
            if st.button("🗑️", key=f"delete_{ticker}", help=f"Delete {ticker}"):
                watchlist.pop(ticker, None)
                st.session_state.timeseries_cache.pop(ticker, None)
                st.rerun()
        st.divider()
def _render_growth_chart(all_series_data):
    """Plot every symbol's closing prices rebased to 100 at its first data point, on one chart."""
    st.markdown("##### Growth Performance Comparison (%)")
    normalized_dfs = []
    for symbol, df in all_series_data.items():
        if df.empty:
            continue
        base_price = df['close'].iloc[0]
        # A missing or zero first close cannot serve as the base (NaN / division by zero).
        if pd.isna(base_price) or base_price == 0:
            continue
        normalized_df = (df['close'] / base_price * 100).reset_index()
        normalized_df.columns = ['datetime', 'value']
        normalized_df['symbol'] = symbol
        normalized_dfs.append(normalized_df)
    if not normalized_dfs:
        st.warning("No data to draw growth chart.")
        return
    full_normalized_df = pd.concat(normalized_dfs)
    y_domain = get_y_axis_domain(full_normalized_df['value'])
    chart = alt.Chart(full_normalized_df).mark_line().encode(
        x=alt.X('datetime:T', title='Time'),
        y=alt.Y('value:Q', scale=alt.Scale(domain=y_domain, zero=False), title='Growth (%)'),
        color=alt.Color('symbol:N', title='Symbol'),
        tooltip=[
            alt.Tooltip('symbol:N', title='Symbol'),
            alt.Tooltip('datetime:T', title='Time', format='%Y-%m-%d %H:%M'),
            alt.Tooltip('value:Q', title='Growth', format='.2f'),
        ],
    ).interactive()
    st.altair_chart(chart, use_container_width=True)


def _render_price_charts(all_series_data):
    """Plot one closing-price line chart per symbol, headed by the currency from its watchlist entry."""
    st.markdown("##### Actual Price Charts")
    for symbol, df in all_series_data.items():
        stock_info = st.session_state.stock_watchlist.get(symbol, {})
        st.markdown(f"**{symbol}** ({stock_info.get('currency', 'N/A')})")
        if df.empty:
            continue
        y_domain = get_y_axis_domain(df['close'])
        price_chart = alt.Chart(df.reset_index()).mark_line().encode(
            x=alt.X('datetime:T', title='Time'),
            y=alt.Y('close:Q', scale=alt.Scale(domain=y_domain, zero=False), title='Price'),
            tooltip=[
                alt.Tooltip('datetime:T', title='Time', format='%Y-%m-%d %H:%M'),
                alt.Tooltip('close:Q', title='Price', format=',.2f'),
            ],
        ).interactive()
        st.altair_chart(price_chart, use_container_width=True)


def render_timeseries_tab():
    """Render the 'Time Charts' tab: period selector, growth comparison and per-symbol price charts.

    Charts are drawn from st.session_state.timeseries_cache. When the user
    picks a different period, every watchlist symbol is re-fetched through
    get_smart_time_series (which refreshes the cache) and the script reruns.
    """
    st.subheader("Chart Analysis")
    watchlist = st.session_state.stock_watchlist
    if not watchlist:
        st.info("Please add at least one stock to the watchlist to view charts.")
        return

    time_periods = {'Intraday': 'intraday', '1 Week': '1_week', '1 Month': '1_month', '6 Months': '6_months', '1 Year': '1_year'}
    period_keys = list(time_periods)
    period_values = list(time_periods.values())
    active_period = st.session_state.active_timeseries_period
    default_index = period_values.index(active_period) if active_period in period_values else 0
    selected_label = st.radio("Select time period:", options=period_keys, horizontal=True, index=default_index)
    selected_period = time_periods[selected_label]

    if active_period != selected_period:
        st.session_state.active_timeseries_period = selected_period
        with st.spinner("Updating charts..."):
            # get_smart_time_series stores each successful result in timeseries_cache.
            for symbol in list(watchlist):
                get_smart_time_series(symbol, selected_period)
        st.rerun()

    cache = st.session_state.timeseries_cache
    all_series_data = {
        symbol: cache[symbol][selected_period]
        for symbol in watchlist
        if selected_period in cache.get(symbol, {})
    }
    if not all_series_data:
        st.warning("Not enough data for the selected time period.")
        return

    _render_growth_chart(all_series_data)
    st.divider()
    _render_price_charts(all_series_data)
def render_currency_tab():
    """Render the 'Currency Converter' tab from st.session_state.currency_converter_state.

    The amount and currency selectors are pre-filled from the stored state,
    which perform_currency_conversion also updates when the AI converts.
    Pressing "Convert" runs perform_currency_conversion and reruns the script.
    The most recent stored result is shown under the form.
    """
    st.subheader("Currency Converter Tool")
    state = st.session_state.currency_converter_state
    currencies = AVAILABLE_CURRENCIES
    if not currencies:
        # No forex pairs were loaded: there is nothing to select, and any
        # default selectbox index would be out of range.
        st.warning("No currencies are available right now. Please try again later.")
        return

    def _index_of(code, fallback):
        # Position of a stored code in the options, else a fallback clamped to the list length.
        if code in currencies:
            return currencies.index(code)
        return min(fallback, len(currencies) - 1)

    col1, col2 = st.columns(2)
    # st.number_input needs value and min_value of the same numeric type and
    # value >= min_value, so coerce the stored amount accordingly.
    start_amount = max(float(state['amount']), 0.0)
    amount = col1.number_input("Amount", value=start_amount, min_value=0.0, format="%.2f", key="conv_amount")
    from_curr = col1.selectbox("From", options=currencies, index=_index_of(state['from'], 0), key="conv_from")
    to_curr = col2.selectbox("To", options=currencies, index=_index_of(state['to'], 1), key="conv_to")

    if st.button("Convert", use_container_width=True, key="conv_btn"):
        with st.spinner("Converting..."):
            perform_currency_conversion(amount, f"{from_curr}/{to_curr}")
        st.rerun()

    res = state['result']
    if res:
        if res.get('status') == 'Success':
            st.success(f"**Result:** `{res['original_amount']:,.2f} {res['path_taken'][0]}` = `{res['final_amount']:,.2f} {res['path_taken'][-1]}`")
        else:
            st.error(f"Error: {res.get('error', 'Unknown')}")
# --- 6. MAIN APP LAYOUT & CONTROL FLOW ---
# Phrases in the user's prompt that force a chart period, checked in order
# (first match wins). They are applied after Gemini's turn, presumably as a
# fallback for when the model does not request the period itself.
_PERIOD_KEYWORDS = (
    ("last year", "1_year"),
    ("last 6 months", "6_months"),
    ("last month", "1_month"),
    ("last week", "1_week"),
)


def _execute_tool_call(call):
    """Run one Gemini function call and wrap the outcome as a FunctionResponse part.

    Unknown function names and exceptions raised by the tool are reported back
    to the model as an ``error`` payload instead of crashing the page.
    """
    func_name = call.name
    func = AVAILABLE_FUNCTIONS.get(func_name)
    if func is None:
        payload = {'error': f"Function '{func_name}' not found."}
    else:
        try:
            payload = {'result': func(**dict(call.args))}
        except Exception as exc:  # let the model explain the failure instead of aborting the run
            payload = {'error': f"{type(exc).__name__}: {exc}"}
    return glm.Part(function_response=glm.FunctionResponse(name=func_name, response=payload))


def _function_calls(response):
    """Return the function calls requested in the first candidate of *response*."""
    return [part.function_call for part in response.candidates[0].content.parts if part.function_call]


def _run_gemini_turn(prompt):
    """Send *prompt* to the chat session, service every tool call, and return the model's final text."""
    chat = st.session_state.chat_session
    response = chat.send_message(prompt)
    tool_calls = _function_calls(response)
    while tool_calls:
        tool_responses = [_execute_tool_call(call) for call in tool_calls]
        response = chat.send_message(glm.Content(parts=tool_responses))
        tool_calls = _function_calls(response)
    return response.text


def _apply_period_keywords(prompt):
    """Switch the chart period if *prompt* mentions one of _PERIOD_KEYWORDS; fetch series not yet cached."""
    text = prompt.lower()
    new_period = next((period for phrase, period in _PERIOD_KEYWORDS if phrase in text), None)
    if new_period is None or new_period == st.session_state.active_timeseries_period:
        return
    st.session_state.active_timeseries_period = new_period
    watchlist = st.session_state.stock_watchlist
    if not watchlist:
        return
    cache = st.session_state.timeseries_cache
    for symbol in list(watchlist):
        if new_period not in cache.get(symbol, {}):
            get_smart_time_series(symbol, new_period)
    st.session_state.active_tab = 'Time Charts'


st.title("📈 AI Financial Dashboard")
# Two equal-width columns: chat with the AI on the left, dashboard tabs on the right.
col1, col2 = st.columns([1, 1])

# Left column: the chat transcript.
with col1:
    chat_container = st.container(height=600)
    with chat_container:
        for message in st.session_state.chat_history:
            with st.chat_message(message["role"]):
                st.markdown(message["parts"])

# Right column: chart and data tabs.
with col2:
    right_column_container = st.container(height=600)
    with right_column_container:
        tab_names = ['Stock Watchlist', 'Time Charts', 'Currency Converter']
        try:
            default_index = tab_names.index(st.session_state.active_tab)
        except ValueError:
            default_index = 0
        st.session_state.active_tab = tab_names[default_index]
        # NOTE(review): st.tabs is called without any selection argument, so
        # active_tab does not change which tab is displayed; it is only
        # normalised here. Confirm whether tab switching is expected to work.
        tab1, tab2, tab3 = st.tabs(tab_names)
        with tab1:
            render_watchlist_tab()
        with tab2:
            render_timeseries_tab()
        with tab3:
            render_currency_tab()

# Chat input at the bottom of the page. A new prompt is recorded and the
# script reruns; the block below then answers it.
user_prompt = st.chat_input("Ask AI to control the dashboard...")
if user_prompt:
    st.session_state.chat_history.append({"role": "user", "parts": user_prompt})
    st.rerun()

# A trailing user message means that prompt has not been answered yet.
if st.session_state.chat_history and st.session_state.chat_history[-1]["role"] == "user":
    last_user_prompt = st.session_state.chat_history[-1]["parts"]
    with chat_container:
        with st.chat_message("model"):
            with st.spinner("🤖 AI executing command..."):
                reply_text = None
                try:
                    reply_text = _run_gemini_turn(last_user_prompt)
                    _apply_period_keywords(last_user_prompt)
                except Exception as exc:  # UI boundary: a model reply must always be recorded
                    # Without a model message the pending prompt would be re-sent
                    # (and fail again) on every subsequent rerun.
                    error_text = f"⚠️ Sorry, I couldn't complete that request ({type(exc).__name__}: {exc})."
                    reply_text = error_text if reply_text is None else f"{reply_text}\n\n{error_text}"
    st.session_state.chat_history.append({"role": "model", "parts": reply_text})
    st.rerun()