# app.py (Final version with Advanced Charts)
import os
from collections import deque
from datetime import datetime

import altair as alt
import google.ai.generativelanguage as glm
import google.generativeai as genai
import pandas as pd
import streamlit as st
from dotenv import load_dotenv

from twelvedata_api import TwelveDataAPI

# --- 1. INITIAL CONFIGURATION & STATE INITIALIZATION ---
load_dotenv()
MODEL_NAME = os.getenv("GEMINI_MODEL", "gemini-2.5-flash")

# Set page config consistent with other pages
st.set_page_config(
    page_title="AI Financial Dashboard",
    page_icon="📊",
    layout="wide",
)

# Configure Gemini API
genai.configure(api_key=os.getenv("GEMINI_API_KEY"))


def initialize_state():
    """Populate ``st.session_state`` with the dashboard's defaults, once per session.

    Keys created: ``td_api`` (Twelve Data client), ``stock_watchlist``
    (symbol -> stock info), ``timeseries_cache`` (symbol -> period -> DataFrame),
    ``active_timeseries_period``, ``currency_converter_state``, ``chat_history``,
    ``active_tab`` and ``chat_session``.

    The ``initialized`` marker is written only after every other key has been
    assigned. If constructing the Twelve Data client raises, the next rerun
    therefore retries the whole initialization. It does not skip it and leave
    the session half-populated.
    """
    if "initialized" in st.session_state:
        return
    st.session_state.td_api = TwelveDataAPI(os.getenv("TWELVEDATA_API_KEY"))
    st.session_state.stock_watchlist = {}
    st.session_state.timeseries_cache = {}
    st.session_state.active_timeseries_period = 'intraday'
    st.session_state.currency_converter_state = {'from': 'USD', 'to': 'VND', 'amount': 100.0, 'result': None}
    st.session_state.chat_history = []
    st.session_state.active_tab = 'Stock Watchlist'
    st.session_state.chat_session = None
    # Mark as initialized last, so a failure above is retried on the next rerun.
    st.session_state.initialized = True


initialize_state()
# --- 2. LOAD BACKGROUND DATA ---
@st.cache_data(show_spinner="Loading and preparing market data...")
def load_market_data():
    """Fetch the stock universe and forex pairs, and derive lookup tables from them.

    Returns:
        A 4-tuple ``(stocks_data, forex_graph, country_currency_map, all_currencies)``:

        * ``stocks_data``: the ``get_all_stocks()`` response, or ``{}`` when the
          API returned nothing. Callers can therefore always use
          ``.get('data', ...)``.
        * ``forex_graph``: undirected adjacency lists
          ``{currency: [neighbour, ...]}`` built from every ``BASE/QUOTE`` pair
          symbol. Malformed symbols are skipped.
        * ``country_currency_map``: lower-cased country name -> currency code,
          taken from the stock listings. The last listing seen for a country wins.
        * ``all_currencies``: sorted currency codes present in ``forex_graph``.

    NOTE(review): this takes no arguments under ``st.cache_data``, so the result
    is presumably shared across sessions. It reads ``td_api`` from whichever
    session populates the cache. Confirm that this is intended.
    """
    td_api = st.session_state.td_api
    stocks_data = td_api.get_all_stocks() or {}
    forex_data = td_api.get_forex_pairs()

    forex_graph = {}
    if forex_data and 'data' in forex_data:
        for pair in forex_data['data'] or []:
            parts = str(pair.get('symbol') or '').split('/')
            if len(parts) != 2 or not all(parts):
                # Skip a malformed pair rather than aborting the whole load with ValueError.
                continue
            base, quote = parts
            # Link both directions so conversions can be routed either way.
            forex_graph.setdefault(base, []).append(quote)
            forex_graph.setdefault(quote, []).append(base)

    country_currency_map = {}
    if 'data' in stocks_data:
        for stock in stocks_data['data'] or []:
            country, currency = stock.get('country'), stock.get('currency')
            if country and currency:
                country_currency_map[country.lower()] = currency

    all_currencies = sorted(forex_graph)
    return stocks_data, forex_graph, country_currency_map, all_currencies


ALL_STOCKS_CACHE, FOREX_GRAPH, COUNTRY_CURRENCY_MAP, AVAILABLE_CURRENCIES = load_market_data()


# --- 3. TOOL EXECUTION LOGIC ---
def find_and_process_stock(query: str):
    """Resolve *query* (a symbol or company name) to a stock, and track it if the match is unique.

    The cached stock universe is searched first, using a case-insensitive
    substring match on symbol or name. When that finds nothing, the search
    falls back to ``td_api.get_stocks(symbol=query)``.

    Returns a status dict that is sent back to Gemini as the tool result:

    * ``SINGLE_STOCK_PROCESSED``: exactly one match. The stock was added to the
      watchlist and its intraday series was requested. When that series was
      fetched, the UI switches to the 'Time Charts' tab on the intraday period.
    * ``MULTIPLE_STOCKS_FOUND``: the first 5 matches, for disambiguation.
    * ``NO_STOCKS_FOUND``.
    """
    print(f"Hybrid searching for stock: '{query}'...")
    query_lower = query.lower()
    found_data = [
        s for s in ALL_STOCKS_CACHE.get('data') or []
        if query_lower in (s.get('symbol') or '').lower()
        or query_lower in (s.get('name') or '').lower()
    ]
    if not found_data:
        results = st.session_state.td_api.get_stocks(symbol=query) or {}
        found_data = results.get('data') or []

    if len(found_data) == 1:
        stock_info = found_data[0]
        symbol = stock_info['symbol']
        st.session_state.stock_watchlist[symbol] = stock_info
        # On success, get_smart_time_series already stores the series in
        # timeseries_cache[symbol]['intraday'], so it is not re-parsed here.
        ts_data = get_smart_time_series(symbol=symbol, time_period='intraday')
        if 'values' in ts_data:
            st.session_state.active_tab = 'Time Charts'
            st.session_state.active_timeseries_period = 'intraday'
        return {"status": "SINGLE_STOCK_PROCESSED", "symbol": symbol, "name": stock_info.get('name', 'N/A')}
    elif len(found_data) > 1:
        return {"status": "MULTIPLE_STOCKS_FOUND", "data": found_data[:5]}
    else:
        return {"status": "NO_STOCKS_FOUND"}


# Period name -> Twelve Data time_series parameters (bar interval, number of bars).
_TIME_SERIES_PARAMS = {
    'intraday': {'interval': '15min', 'outputsize': 120},
    '1_week': {'interval': '1h', 'outputsize': 40},
    '1_month': {'interval': '1day', 'outputsize': 22},
    '6_months': {'interval': '1day', 'outputsize': 120},
    '1_year': {'interval': '1week', 'outputsize': 52},
}


def get_smart_time_series(symbol: str, time_period: str):
    """Fetch price history for *symbol* over *time_period* and cache it for charting.

    If *symbol* is not on the watchlist yet, it is looked up with ``get_stocks``.
    When found, the first match is added to the watchlist and the UI switches
    to the 'Time Charts' tab.

    Args:
        symbol: Stock symbol, e.g. ``"AAPL"``.
        time_period: One of ``intraday``, ``1_week``, ``1_month``, ``6_months``
            or ``1_year``.

    Returns:
        The raw ``get_time_series`` response, or ``{"error": ...}`` for an
        unknown period. When the response contains ``values``, they are stored
        in ``timeseries_cache[symbol][time_period]`` as an ascending,
        datetime-indexed DataFrame with a numeric ``close`` column. That period
        also becomes the active chart period, so a period requested through the
        AI tool is the one the 'Time Charts' tab displays.
    """
    # Make sure the symbol is on the watchlist before fetching its history.
    if symbol not in st.session_state.stock_watchlist:
        results = st.session_state.td_api.get_stocks(symbol=symbol) or {}
        found_data = results.get('data') or []
        if found_data:
            st.session_state.stock_watchlist[symbol] = found_data[0]
            st.session_state.active_tab = 'Time Charts'

    params = _TIME_SERIES_PARAMS.get(time_period)
    if not params:
        return {"error": f"Time period '{time_period}' is not valid."}

    result = st.session_state.td_api.get_time_series(symbol=symbol, **params) or {}
    # On success, refresh the cache and show this period in the chart tab.
    if 'values' in result:
        df = pd.DataFrame(result['values'])
        df['datetime'] = pd.to_datetime(df['datetime'])
        df['close'] = pd.to_numeric(df['close'])
        st.session_state.timeseries_cache.setdefault(symbol, {})[time_period] = (
            df.sort_values('datetime').set_index('datetime')
        )
        st.session_state.active_timeseries_period = time_period
    return result


def find_conversion_path_bfs(start, end):
    """Return the currency path with the fewest hops from *start* to *end* in ``FOREX_GRAPH``.

    The search is breadth-first. The result lists currency codes starting with
    *start* and ending with *end*; it is ``[start]`` when the two are equal.
    Returns None when either currency is unknown or the two are not connected.
    """
    if start not in FOREX_GRAPH or end not in FOREX_GRAPH:
        return None
    queue = deque([(start, [start])])
    visited = {start}
    while queue:
        current, path = queue.popleft()
        if current == end:
            return path
        for neighbor in FOREX_GRAPH.get(current, []):
            if neighbor not in visited:
                visited.add(neighbor)
                queue.append((neighbor, path + [neighbor]))
    return None


def convert_currency_with_bridge(amount: float, symbol: str):
    """Convert *amount* along the path with the fewest hops for the pair *symbol*.

    Args:
        amount: Amount in the source currency.
        symbol: Pair in ``"FROM/TO"`` form (case-insensitive), e.g. ``"usd/vnd"``.

    Returns:
        On success, ``{"status": "Success", "original_amount", "final_amount",
        "path_taken", "conversion_steps"}``. On failure, ``{"error": message}``.

        Each hop is first converted with the direct pair. If that yields no
        rate, the rate of the opposite pair, quoted for 1 unit, is inverted.
    """
    try:
        start_currency, end_currency = symbol.upper().split('/')
    except ValueError:
        return {"error": "Invalid currency pair format."}

    path = find_conversion_path_bfs(start_currency, end_currency)
    if not path:
        return {"error": f"No conversion path found from {start_currency} to {end_currency}."}

    td_api = st.session_state.td_api
    current_amount = amount
    steps = []
    for i, (step_start, step_end) in enumerate(zip(path, path[1:]), start=1):
        result = td_api.currency_conversion(amount=current_amount, symbol=f"{step_start}/{step_end}") or {}
        if result.get('rate') is not None:
            converted = result.get('amount')
            # Fall back to amount * rate if the API omits the converted amount.
            current_amount = converted if converted is not None else current_amount * float(result['rate'])
            steps.append({"step": f"{i}. {step_start} → {step_end}", "rate": result['rate'], "intermediate_amount": current_amount})
            continue

        # No direct rate: invert the opposite pair's rate.
        inverse_result = td_api.currency_conversion(amount=1, symbol=f"{step_end}/{step_start}") or {}
        try:
            rate = 1 / float(inverse_result.get('rate'))
        except (TypeError, ValueError, ZeroDivisionError):
            # Missing, non-numeric or zero inverse rate.
            return {"error": f"Error in conversion step from {step_start} to {step_end}."}
        current_amount *= rate
        steps.append({"step": f"{i}. {step_start} → {step_end} (Inverse)", "rate": rate, "intermediate_amount": current_amount})

    return {"status": "Success", "original_amount": amount, "final_amount": current_amount, "path_taken": path, "conversion_steps": steps}


def perform_currency_conversion(amount: float, symbol: str):
    """Run a conversion for the Gemini tool and mirror it into the converter tab.

    The result and *amount* are stored in ``currency_converter_state``. When
    *symbol* is a well-formed ``FROM/TO`` pair, the upper-cased codes are stored
    too, so they match the options in ``AVAILABLE_CURRENCIES``. The active tab
    switches to 'Currency Converter'.

    Returns:
        The dict produced by ``convert_currency_with_bridge``.
    """
    result = convert_currency_with_bridge(amount, symbol)
    state = st.session_state.currency_converter_state
    state.update({'result': result, 'amount': amount})
    try:
        from_curr, to_curr = symbol.upper().split('/')
    except ValueError:
        # Malformed pair: the error is already recorded in `result`; keep the previous widgets.
        pass
    else:
        state.update({'from': from_curr, 'to': to_curr})
    st.session_state.active_tab = 'Currency Converter'
    return result


# --- 4. GEMINI CONFIGURATION ---
SYSTEM_INSTRUCTION = """You are the AI brain controlling an Interactive Financial Dashboard. Your task is to understand user requests, call appropriate tools, and report results concisely.

GOLDEN RULES:
1. **UNDERSTAND FIRST, CALL LATER:**
    * **Company Name:** When a user enters a company name (e.g., "Vingroup Corporation", "Apple"), your FIRST task is to use the `find_and_process_stock` tool to identify the official stock symbol.
    * **Stock Symbol:** When a user directly provides a stock symbol (e.g., "AAPL", "VNM"), use `find_and_process_stock` first to confirm and add it to the watchlist.
    * **Time Period Request:** When user asks for specific time period (e.g., "last year", "last month"), first make sure the stock symbol is processed with `find_and_process_stock`, then use `get_smart_time_series` with appropriate time_period.
    * **Country Name:** When a user enters a country name for currency (e.g., "Vietnamese currency"), you must infer the 3-letter currency code ("VND") BEFORE calling the `perform_currency_conversion` tool.
2. **ACT AND NOTIFY:** Your role is to execute commands and report briefly.
    * **Found 1 symbol:** "I've found [Company Name] ([Symbol]) and automatically added it to your watchlist and chart."
    * **Found multiple symbols:** "I found several results for '[query]'. Please specify which exact symbol you want to track?"
    * **Currency conversion:** "Done. Please see detailed results in the 'Currency Converter' tab."
3. **NO DATA LISTING:** The dashboard already displays everything. ABSOLUTELY do not repeat lists, numbers, or raw data in your response.
"""


@st.cache_resource
def get_model_and_tools():
    """Build the Gemini model with the dashboard's three function-calling tools.

    The declared tools mirror ``find_and_process_stock``, ``get_smart_time_series``
    and ``perform_currency_conversion`` (see ``AVAILABLE_FUNCTIONS``). The model
    uses ``MODEL_NAME`` and ``SYSTEM_INSTRUCTION``.
    """
    find_stock_func = glm.FunctionDeclaration(
        name="find_and_process_stock",
        description="Search for stock by symbol or name and automatically process. Use this tool FIRST to identify the official stock symbol.",
        parameters=glm.Schema(
            type=glm.Type.OBJECT,
            properties={'query': glm.Schema(type=glm.Type.STRING, description="Symbol or company name, e.g., 'Vingroup', 'Apple'.")},
            required=['query'],
        ),
    )
    get_ts_func = glm.FunctionDeclaration(
        name="get_smart_time_series",
        description="Get price history data after knowing the official stock symbol.",
        parameters=glm.Schema(
            type=glm.Type.OBJECT,
            properties={
                'symbol': glm.Schema(type=glm.Type.STRING),
                'time_period': glm.Schema(type=glm.Type.STRING, enum=["intraday", "1_week", "1_month", "6_months", "1_year"]),
            },
            required=['symbol', 'time_period'],
        ),
    )
    currency_func = glm.FunctionDeclaration(
        name="perform_currency_conversion",
        description="Convert currency after knowing the 3-letter code of source/target currency pair, e.g., USD/VND, JPY/EUR",
        parameters=glm.Schema(
            type=glm.Type.OBJECT,
            properties={'amount': glm.Schema(type=glm.Type.NUMBER), 'symbol': glm.Schema(type=glm.Type.STRING)},
            required=['amount', 'symbol'],
        ),
    )
    finance_tool = glm.Tool(function_declarations=[find_stock_func, get_ts_func, currency_func])
    model = genai.GenerativeModel(model_name=MODEL_NAME, tools=[finance_tool], system_instruction=SYSTEM_INSTRUCTION)
    return model


model = get_model_and_tools()
# One chat session per browser session; it keeps the conversation context.
if st.session_state.chat_session is None:
    st.session_state.chat_session = model.start_chat(history=[])

# Tool name (as declared to Gemini) -> Python implementation.
AVAILABLE_FUNCTIONS = {
    "find_and_process_stock": find_and_process_stock,
    "get_smart_time_series": get_smart_time_series,
    "perform_currency_conversion": perform_currency_conversion,
}
# --- 5. TAB DISPLAY LOGIC ---
def get_y_axis_domain(series: pd.Series, padding_percent: float = 0.1):
    """Return a padded ``[low, high]`` y-axis domain for *series*, or None.

    The min/max span is widened on each side by ``padding_percent`` of the data
    range. A flat series (range 0) is padded by ``padding_percent / 2`` of the
    value's magnitude instead. If that is also 0 (e.g. an all-zero series),
    a padding of 1.0 is used, so the domain is never zero-width. Returns None
    for an empty or all-NaN series. The bounds are plain Python floats.
    """
    if series.empty:
        return None
    data_min, data_max = series.min(), series.max()
    if pd.isna(data_min) or pd.isna(data_max):
        return None
    data_min, data_max = float(data_min), float(data_max)
    data_range = data_max - data_min
    if data_range == 0:
        # `or 1.0`: a flat series at 0 would otherwise give the degenerate domain [0, 0].
        padding = abs(data_max) * (padding_percent / 2) or 1.0
    else:
        padding = data_range * padding_percent
    return [data_min - padding, data_max + padding]


def render_watchlist_tab():
    """Render one row per watched stock, each with a delete button.

    Deleting a stock also drops its cached time series and reruns the script.
    """
    st.subheader("Watchlist")
    if not st.session_state.stock_watchlist:
        st.info("No stocks yet. Try searching for a symbol like 'Apple' or 'VNM'.")
        return
    # Iterate over a snapshot, because the delete button mutates the watchlist.
    for symbol, stock_info in list(st.session_state.stock_watchlist.items()):
        col1, col2, col3 = st.columns([4, 4, 1])
        with col1:
            st.markdown(f"**{symbol}**")
            st.caption(stock_info.get('name', 'N/A'))
        with col2:
            st.markdown(f"**{stock_info.get('exchange', 'N/A')}**")
            st.caption(f"{stock_info.get('country', 'N/A')} - {stock_info.get('currency', 'N/A')}")
        with col3:
            if st.button("🗑️", key=f"delete_{symbol}", help=f"Delete {symbol}"):
                st.session_state.stock_watchlist.pop(symbol, None)
                st.session_state.timeseries_cache.pop(symbol, None)
                st.rerun()
        st.divider()


def render_timeseries_tab():
    """Render the growth-comparison chart and one price chart per watched stock.

    Changing the period radio re-fetches that period for every watched symbol
    and reruns. Only symbols with cached data for the selected period are drawn.
    """
    st.subheader("Chart Analysis")
    if not st.session_state.stock_watchlist:
        st.info("Please add at least one stock to the watchlist to view charts.")
        return

    time_periods = {'Intraday': 'intraday', '1 Week': '1_week', '1 Month': '1_month', '6 Months': '6_months', '1 Year': '1_year'}
    period_keys = list(time_periods.keys())
    period_values = list(time_periods.values())
    active_period = st.session_state.active_timeseries_period
    default_index = period_values.index(active_period) if active_period in period_values else 0
    selected_label = st.radio("Select time period:", options=period_keys, horizontal=True, index=default_index)
    selected_period = time_periods[selected_label]

    if active_period != selected_period:
        st.session_state.active_timeseries_period = selected_period
        with st.spinner("Updating charts..."):
            # get_smart_time_series stores each successful response in timeseries_cache itself.
            for symbol in list(st.session_state.stock_watchlist):
                get_smart_time_series(symbol, selected_period)
        st.rerun()

    cache = st.session_state.timeseries_cache
    all_series_data = {
        symbol: cache[symbol][selected_period]
        for symbol in st.session_state.stock_watchlist.keys()
        if symbol in cache and selected_period in cache[symbol]
    }
    if not all_series_data:
        st.warning("Not enough data for the selected time period.")
        return

    st.markdown("##### Growth Performance Comparison (%)")
    normalized_dfs = []
    for symbol, df in all_series_data.items():
        if df.empty:
            continue
        base_price = df['close'].iloc[0]
        # A zero or NaN starting price would turn the whole growth line into inf/NaN.
        if pd.isna(base_price) or base_price == 0:
            continue
        # Rebase to 100 at the first bar, so series with different prices are comparable.
        normalized_series = (df['close'] / base_price) * 100
        normalized_df = normalized_series.reset_index()
        normalized_df.columns = ['datetime', 'value']
        normalized_df['symbol'] = symbol
        normalized_dfs.append(normalized_df)

    if normalized_dfs:
        full_normalized_df = pd.concat(normalized_dfs)
        y_domain = get_y_axis_domain(full_normalized_df['value'])
        chart = alt.Chart(full_normalized_df).mark_line().encode(
            x=alt.X('datetime:T', title='Time'),
            y=alt.Y('value:Q', scale=alt.Scale(domain=y_domain, zero=False), title='Growth (%)'),
            color=alt.Color('symbol:N', title='Symbol'),
            tooltip=[
                alt.Tooltip('symbol:N', title='Symbol'),
                alt.Tooltip('datetime:T', title='Time', format='%Y-%m-%d %H:%M'),
                alt.Tooltip('value:Q', title='Growth', format='.2f'),
            ],
        ).interactive()
        st.altair_chart(chart, use_container_width=True)
    else:
        st.warning("No data to draw growth chart.")

    st.divider()
    st.markdown("##### Actual Price Charts")
    for symbol, df in all_series_data.items():
        stock_info = st.session_state.stock_watchlist.get(symbol, {})
        st.markdown(f"**{symbol}** ({stock_info.get('currency', 'N/A')})")
        if not df.empty:
            y_domain = get_y_axis_domain(df['close'])
            data_for_chart = df.reset_index()
            price_chart = alt.Chart(data_for_chart).mark_line().encode(
                x=alt.X('datetime:T', title='Time'),
                y=alt.Y('close:Q', scale=alt.Scale(domain=y_domain, zero=False), title='Price'),
                tooltip=[
                    alt.Tooltip('datetime:T', title='Time', format='%Y-%m-%d %H:%M'),
                    alt.Tooltip('close:Q', title='Price', format=',.2f'),
                ],
            ).interactive()
            st.altair_chart(price_chart, use_container_width=True)


def render_currency_tab():
    """Render the currency converter form and the last conversion result.

    The widgets are pre-filled from ``currency_converter_state``, which the
    Gemini tool may also update. Clicking Convert runs the conversion and reruns.
    """
    st.subheader("Currency Converter Tool")
    state = st.session_state.currency_converter_state
    col1, col2 = st.columns(2)
    # Clamp to the widget's min_value; a negative amount from the AI tool would otherwise raise.
    amount = col1.number_input("Amount", value=max(float(state['amount']), 0.0), min_value=0.0, format="%.2f", key="conv_amount")
    from_index = AVAILABLE_CURRENCIES.index(state['from']) if state['from'] in AVAILABLE_CURRENCIES else 0
    # Default "To" is the second currency, but only if there is one.
    fallback_to = 1 if len(AVAILABLE_CURRENCIES) > 1 else 0
    to_index = AVAILABLE_CURRENCIES.index(state['to']) if state['to'] in AVAILABLE_CURRENCIES else fallback_to
    from_curr = col1.selectbox("From", options=AVAILABLE_CURRENCIES, index=from_index, key="conv_from")
    to_curr = col2.selectbox("To", options=AVAILABLE_CURRENCIES, index=to_index, key="conv_to")

    if st.button("Convert", use_container_width=True, key="conv_btn"):
        with st.spinner("Converting..."):
            perform_currency_conversion(amount, f"{from_curr}/{to_curr}")
        st.rerun()

    if state['result']:
        res = state['result']
        if res.get('status') == 'Success':
            st.success(f"**Result:** `{res['original_amount']:,.2f} {res['path_taken'][0]}` = `{res['final_amount']:,.2f} {res['path_taken'][-1]}`")
        else:
            st.error(f"Error: {res.get('error', 'Unknown')}")
# --- 6. MAIN APP LAYOUT & CONTROL FLOW ---
st.title("📈 AI Financial Dashboard")

# Phrases in the user's prompt that force a chart period. They are checked in
# this order, and the first phrase found wins.
_PERIOD_KEYWORDS = (
    ("last year", "1_year"),
    ("last 6 months", "6_months"),
    ("last month", "1_month"),
    ("last week", "1_week"),
)


def _function_calls(response):
    """Return the function_call parts of the first candidate of *response*, or [] if there is no candidate."""
    if not response.candidates:
        return []
    return [part.function_call for part in response.candidates[0].content.parts if part.function_call]


def _run_tool(call):
    """Execute one Gemini function call and wrap its outcome as a function_response Part.

    An unknown tool name, or an exception raised by the tool, is reported to the
    model as ``{'error': ...}`` instead of crashing the page.
    """
    func_name = call.name
    func = AVAILABLE_FUNCTIONS.get(func_name)
    if func is None:
        payload = {'error': f"Function '{func_name}' not found."}
    else:
        try:
            payload = {'result': func(**dict(call.args.items()))}
        except Exception as exc:  # tool boundary: let the model explain the failure
            print(f"Tool '{func_name}' failed: {exc!r}")
            payload = {'error': f"Function '{func_name}' failed: {exc}"}
    return glm.Part(function_response=glm.FunctionResponse(name=func_name, response=payload))


def _response_text(response):
    """Return the text of *response*, or a fallback message when it has no usable text.

    NOTE(review): this relies on ``response.text`` raising ValueError when the
    reply has no text part (e.g. an empty or blocked candidate). Confirm this
    against the installed google-generativeai version.
    """
    try:
        return response.text
    except ValueError:
        return "⚠️ The AI did not return a text response."


def _run_ai_turn(prompt):
    """Send *prompt* to the chat session and resolve tool calls until the model stops requesting them.

    Returns the model's final text reply.
    """
    chat = st.session_state.chat_session
    response = chat.send_message(prompt)
    tool_calls = _function_calls(response)
    while tool_calls:
        tool_responses = [_run_tool(call) for call in tool_calls]
        response = chat.send_message(glm.Content(parts=tool_responses))
        tool_calls = _function_calls(response)
    return _response_text(response)


def _sync_period_from_prompt(prompt):
    """Switch the chart period when *prompt* contains one of ``_PERIOD_KEYWORDS``.

    When the period actually changes and the watchlist is non-empty, the new
    period is fetched for every symbol that does not have it cached yet. The
    UI then switches to the 'Time Charts' tab.
    """
    old_period = st.session_state.active_timeseries_period
    prompt_lower = (prompt or "").lower()
    for phrase, period in _PERIOD_KEYWORDS:
        if phrase in prompt_lower:
            st.session_state.active_timeseries_period = period
            break

    new_period = st.session_state.active_timeseries_period
    if new_period == old_period or not st.session_state.stock_watchlist:
        return
    cache = st.session_state.timeseries_cache
    for symbol in list(st.session_state.stock_watchlist):
        if new_period not in cache.get(symbol, {}):
            get_smart_time_series(symbol, new_period)
    st.session_state.active_tab = 'Time Charts'


# Split the layout into two columns.
col1, col2 = st.columns([1, 1])

# Left column: chat with the AI.
with col1:
    chat_container = st.container(height=600)
    with chat_container:
        for message in st.session_state.chat_history:
            with st.chat_message(message["role"]):
                st.markdown(message["parts"])

# Right column: tabs with charts and data.
with col2:
    right_column_container = st.container(height=600)
    with right_column_container:
        tab_names = ['Stock Watchlist', 'Time Charts', 'Currency Converter']
        # Normalize an unknown active_tab to the first tab.
        # NOTE(review): st.tabs below is not told which tab to open, so
        # active_tab does not currently select a tab. Confirm this is intended.
        if st.session_state.active_tab not in tab_names:
            st.session_state.active_tab = tab_names[0]
        tab1, tab2, tab3 = st.tabs(tab_names)
        with tab1:
            render_watchlist_tab()
        with tab2:
            render_timeseries_tab()
        with tab3:
            render_currency_tab()

# Chat input at the bottom: record the prompt, then rerun so it renders in the history.
user_prompt = st.chat_input("Ask AI to control the dashboard...")
if user_prompt:
    st.session_state.chat_history.append({"role": "user", "parts": user_prompt})
    st.rerun()

# Answer a pending user prompt (the last history entry is from the user).
if st.session_state.chat_history and st.session_state.chat_history[-1]["role"] == "user":
    last_user_prompt = st.session_state.chat_history[-1]["parts"]
    with chat_container:
        with st.chat_message("model"):
            with st.spinner("🤖 AI executing command..."):
                try:
                    reply_text = _run_ai_turn(last_user_prompt)
                except Exception as exc:  # top-level boundary for the AI turn
                    print(f"AI turn failed: {exc!r}")
                    # Still record a model reply: while the last entry stays "user",
                    # every rerun would re-send the same prompt.
                    st.session_state.chat_history.append(
                        {"role": "model", "parts": f"⚠️ Sorry, something went wrong while executing the command: {exc}"}
                    )
                else:
                    # Record the reply before further API work, so a later failure
                    # cannot leave the prompt pending.
                    st.session_state.chat_history.append({"role": "model", "parts": reply_text})
                    _sync_period_from_prompt(last_user_prompt)
    st.rerun()