Spaces:
Sleeping
Sleeping
File size: 21,755 Bytes
dec9e8e f63d7fd dec9e8e f63d7fd dec9e8e f63d7fd 44f095f 57d91e6 e48425a 57d91e6 f63d7fd dec9e8e f63d7fd dec9e8e f63d7fd dec9e8e f63d7fd dec9e8e f63d7fd e48425a f63d7fd dec9e8e e48425a f63d7fd dec9e8e f63d7fd dec9e8e f63d7fd dec9e8e f63d7fd dec9e8e f63d7fd dec9e8e f63d7fd dec9e8e e48425a dec9e8e f63d7fd dec9e8e f63d7fd 44f095f f63d7fd dec9e8e f63d7fd dec9e8e f63d7fd dec9e8e f63d7fd dec9e8e f63d7fd dec9e8e f63d7fd dec9e8e f63d7fd dec9e8e f63d7fd dec9e8e f63d7fd dec9e8e f63d7fd dec9e8e f63d7fd dec9e8e f63d7fd dec9e8e f63d7fd dec9e8e f63d7fd dec9e8e f63d7fd dec9e8e f63d7fd e48425a f63d7fd e48425a f63d7fd e48425a f63d7fd e48425a f63d7fd e48425a 57d91e6 e48425a f63d7fd e48425a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 |
# app.py (Final version with Advanced Charts)
import streamlit as st
import pandas as pd
import altair as alt # <-- Add Altair library
import google.generativeai as genai
import google.ai.generativelanguage as glm
from dotenv import load_dotenv
import os
from twelvedata_api import TwelveDataAPI
from collections import deque
from datetime import datetime
# --- 1. INITIAL CONFIGURATION & STATE INITIALIZATION ---
# Load variables from a local .env file (if one exists) before the os.getenv() lookups below.
load_dotenv()
# Gemini model id used by get_model_and_tools(); override it with the GEMINI_MODEL env variable.
MODEL_NAME = os.getenv("GEMINI_MODEL", "gemini-2.5-flash")
# Set page config consistent with other pages.
# NOTE: keep this before any other Streamlit command; older Streamlit versions require
# set_page_config to be the first Streamlit call of a script run.
st.set_page_config(
page_title="AI Financial Dashboard",
page_icon="📊",
layout="wide"
)
# Configure the Gemini client with the key from the GEMINI_API_KEY environment variable.
genai.configure(api_key=os.getenv("GEMINI_API_KEY"))
def initialize_state():
    """Seed ``st.session_state`` with the dashboard defaults the first time a session runs.

    The ``initialized`` flag makes every later rerun a no-op, so user state (watchlist,
    cached charts, chat history, converter inputs) survives Streamlit reruns.
    """
    session = st.session_state
    if "initialized" in session:
        return
    session.initialized = True
    # One Twelve Data client per browser session, shared by all tool functions.
    session.td_api = TwelveDataAPI(os.getenv("TWELVEDATA_API_KEY"))
    fresh_defaults = {
        "stock_watchlist": {},                   # symbol -> stock info dict
        "timeseries_cache": {},                  # symbol -> {period key -> DataFrame}
        "active_timeseries_period": 'intraday',
        "currency_converter_state": {'from': 'USD', 'to': 'VND', 'amount': 100.0, 'result': None},
        "chat_history": [],
        "active_tab": 'Stock Watchlist',
        "chat_session": None,                    # Gemini chat, created lazily after the model
    }
    for attr_name, initial_value in fresh_defaults.items():
        setattr(session, attr_name, initial_value)


initialize_state()
# --- 2. LOAD BACKGROUND DATA ---
def _build_forex_graph(pairs):
    """Build an undirected currency graph ``{code: [neighbour, ...]}`` from forex pair records.

    Each record is expected to carry a ``'symbol'`` such as ``'EUR/USD'``. Records whose
    symbol is missing or does not split into exactly two non-empty codes are skipped instead
    of aborting the whole load. Each edge is recorded once even if the pair is listed repeatedly.
    """
    graph = {}
    for pair in pairs:
        parts = (pair.get('symbol') or '').split('/')
        if len(parts) != 2 or not all(parts):
            continue
        base, quote = parts
        # Record the edge in both directions (base -> quote, quote -> base).
        for node, neighbour in ((base, quote), (quote, base)):
            neighbours = graph.setdefault(node, [])
            if neighbour not in neighbours:
                neighbours.append(neighbour)
    return graph


@st.cache_data(show_spinner="Loading and preparing market data...")
def load_market_data():
    """Fetch the stock and forex reference lists once and derive the lookup tables.

    Returns:
        tuple: ``(stocks_data, forex_graph, country_currency_map, all_currencies)`` where
            - ``stocks_data`` is the ``get_all_stocks()`` payload, or ``{}`` when the client
              returned nothing, so callers can always use ``.get('data', [])``;
            - ``forex_graph`` is the adjacency mapping built by ``_build_forex_graph``;
            - ``country_currency_map`` maps a lower-cased country name to a currency code,
              taken from the stock listings (the last listing seen for a country wins);
            - ``all_currencies`` is the sorted list of currency codes in ``forex_graph``.
    """
    td_api = st.session_state.td_api
    stocks_data = td_api.get_all_stocks() or {}
    forex_data = td_api.get_forex_pairs() or {}
    forex_graph = _build_forex_graph(forex_data.get('data') or [])
    country_currency_map = {}
    for stock in stocks_data.get('data') or []:
        country, currency = stock.get('country'), stock.get('currency')
        if country and currency:
            country_currency_map[country.lower()] = currency
    all_currencies = sorted(forex_graph.keys())
    return stocks_data, forex_graph, country_currency_map, all_currencies


ALL_STOCKS_CACHE, FOREX_GRAPH, COUNTRY_CURRENCY_MAP, AVAILABLE_CURRENCIES = load_market_data()
# --- 3. TOOL EXECUTION LOGIC ---
def _unique_by_symbol(stocks):
    """Keep the first record for each symbol, dropping records that have no symbol.

    The watchlist and time-series cache are keyed by symbol alone, so several listings of one
    symbol (e.g. on different exchanges) would become the same dashboard entry anyway.
    """
    seen = set()
    unique = []
    for stock in stocks:
        symbol = stock.get('symbol')
        if symbol and symbol not in seen:
            seen.add(symbol)
            unique.append(stock)
    return unique


def _match_cached_stocks(query, stocks):
    """Return the cached stock records matching *query*, most specific tier only.

    Tiers, all case-insensitive: exact symbol match, then exact company-name match, then a
    substring match on symbol or name. The first non-empty tier is returned, one record per symbol.
    """
    needle = query.strip().lower()
    exact_symbol, exact_name, partial = [], [], []
    for stock in stocks:
        symbol = (stock.get('symbol') or '').lower()
        name = (stock.get('name') or '').lower()
        if symbol == needle:
            exact_symbol.append(stock)
        elif name == needle:
            exact_name.append(stock)
        elif needle in symbol or needle in name:
            partial.append(stock)
    return _unique_by_symbol(exact_symbol or exact_name or partial)


def find_and_process_stock(query: str):
    """Resolve *query* (a symbol or company name) and, if it is unambiguous, load it into the dashboard.

    Search order: the preloaded ``ALL_STOCKS_CACHE`` (via ``_match_cached_stocks``), then
    ``td_api.get_stocks(symbol=query)`` as a fallback. Listings that share a symbol count once.

    Returns:
        ``{"status": "SINGLE_STOCK_PROCESSED", "symbol", "name"}`` after adding the stock to the
        watchlist, fetching its intraday series and switching to the 'Time Charts' tab;
        ``{"status": "MULTIPLE_STOCKS_FOUND", "data": [...]}`` with at most 5 candidates; or
        ``{"status": "NO_STOCKS_FOUND"}`` (also for an empty query).
    """
    print(f"Hybrid searching for stock: '{query}'...")
    query = (query or '').strip()
    if not query:
        return {"status": "NO_STOCKS_FOUND"}
    cached_stocks = (ALL_STOCKS_CACHE or {}).get('data') or []
    found_data = _match_cached_stocks(query, cached_stocks)
    if not found_data:
        results = st.session_state.td_api.get_stocks(symbol=query) or {}
        found_data = _unique_by_symbol(results.get('data') or [])
    if len(found_data) > 1:
        return {"status": "MULTIPLE_STOCKS_FOUND", "data": found_data[:5]}
    if not found_data:
        return {"status": "NO_STOCKS_FOUND"}
    stock_info = found_data[0]
    symbol = stock_info['symbol']
    st.session_state.stock_watchlist[symbol] = stock_info
    # get_smart_time_series parses the response and stores the intraday DataFrame in
    # st.session_state.timeseries_cache itself, so no second parse is needed here.
    get_smart_time_series(symbol=symbol, time_period='intraday')
    st.session_state.active_tab = 'Time Charts'
    st.session_state.active_timeseries_period = 'intraday'
    return {"status": "SINGLE_STOCK_PROCESSED", "symbol": symbol, "name": stock_info.get('name', 'N/A')}
def get_smart_time_series(symbol: str, time_period: str):
    """Fetch price history for *symbol* over a named period and cache it for the charts.

    If *symbol* is not yet in the watchlist, it is looked up with ``td_api.get_stocks``. When
    found, the first returned listing is added and the active tab switched to 'Time Charts'.

    Args:
        symbol: Ticker symbol passed straight to the Twelve Data client.
        time_period: One of 'intraday', '1_week', '1_month', '6_months', '1_year'; each maps
            to the fixed ``interval`` / ``outputsize`` request parameters below.

    Returns:
        ``{"error": ...}`` for an unknown *time_period* (without touching any state); otherwise
        the raw ``get_time_series`` response, or ``{}`` if the client returned nothing. When the
        response contains ``'values'``, a DataFrame indexed by ``datetime`` with a numeric
        ``close`` column is stored in ``st.session_state.timeseries_cache[symbol][time_period]``.
    """
    logic_map = {
        'intraday': {'interval': '15min', 'outputsize': 120},
        '1_week': {'interval': '1h', 'outputsize': 40},
        '1_month': {'interval': '1day', 'outputsize': 22},
        '6_months': {'interval': '1day', 'outputsize': 120},
        '1_year': {'interval': '1week', 'outputsize': 52},
    }
    params = logic_map.get(time_period)
    # Validate first so an invalid request does not modify the watchlist or the active tab.
    if not params:
        return {"error": f"Time period '{time_period}' is not valid."}
    state = st.session_state
    # Add the symbol to the watchlist first if it is not there yet.
    if symbol not in state.stock_watchlist:
        lookup = state.td_api.get_stocks(symbol=symbol) or {}
        found_data = lookup.get('data') or []
        if found_data:
            state.stock_watchlist[symbol] = found_data[0]
            state.active_tab = 'Time Charts'
    result = state.td_api.get_time_series(symbol=symbol, **params) or {}
    # On success, parse the values and update the chart cache.
    if 'values' in result:
        df = pd.DataFrame(result['values'])
        df['datetime'] = pd.to_datetime(df['datetime'])
        df['close'] = pd.to_numeric(df['close'])
        state.timeseries_cache.setdefault(symbol, {})[time_period] = (
            df.sort_values('datetime').set_index('datetime')
        )
    return result
def find_conversion_path_bfs(start, end, graph=None):
    """Return the shortest currency path from *start* to *end*, e.g. ``['USD', 'EUR', 'VND']``.

    Breadth-first search over an adjacency mapping ``{code: [neighbour, ...]}``.

    Args:
        start: Source currency code.
        end: Target currency code.
        graph: Adjacency mapping to search; defaults to the module-level ``FOREX_GRAPH``.

    Returns:
        The list of codes from *start* to *end* inclusive (``[start]`` when they are equal), or
        ``None`` when either code is absent from *graph* or no path connects them.
    """
    if graph is None:
        graph = FOREX_GRAPH
    if start not in graph or end not in graph:
        return None
    # parent[code] is the code from which `code` was first reached; the path is rebuilt from it
    # instead of copying a growing path list for every queued node.
    parent = {start: None}
    queue = deque([start])
    while queue:
        current = queue.popleft()
        if current == end:
            path = []
            while current is not None:
                path.append(current)
                current = parent[current]
            return path[::-1]
        for neighbour in graph.get(current, []):
            if neighbour not in parent:
                parent[neighbour] = current
                queue.append(neighbour)
    return None
def convert_currency_with_bridge(amount: float, symbol: str):
    """Convert *amount* along the shortest forex path for a ``'FROM/TO'`` currency pair.

    The pair is upper-cased and each code stripped of whitespace. Each hop of the path from
    ``find_conversion_path_bfs`` is priced with ``td_api.currency_conversion``. When a hop has
    no direct rate, the inverse pair is quoted for 1 unit and its reciprocal rate is applied.

    Returns:
        On success ``{"status": "Success", "original_amount", "final_amount", "path_taken",
        "conversion_steps"}``; otherwise ``{"error": message}`` for a malformed pair, a missing
        path, or a hop that could not be priced in either direction.
    """
    parts = [part.strip() for part in str(symbol).upper().split('/')]
    # Require exactly two non-empty codes ("USD/", "A/B/C" and "USDVND" are all rejected).
    if len(parts) != 2 or not all(parts):
        return {"error": "Invalid currency pair format."}
    start_currency, end_currency = parts
    path = find_conversion_path_bfs(start_currency, end_currency)
    if not path:
        return {"error": f"No conversion path found from {start_currency} to {end_currency}."}
    api = st.session_state.td_api
    current_amount = amount
    steps = []
    for step_no, (step_start, step_end) in enumerate(zip(path, path[1:]), start=1):
        quote = api.currency_conversion(amount=current_amount, symbol=f"{step_start}/{step_end}") or {}
        rate = quote.get('rate')
        if rate is not None:
            converted = quote.get('amount')
            # Prefer the API's converted amount; derive it from the rate only if the field is missing.
            current_amount = converted if converted is not None else current_amount * rate
            steps.append({"step": f"{step_no}. {step_start} → {step_end}", "rate": rate, "intermediate_amount": current_amount})
            continue
        # No direct quote for this hop: price the inverse pair for 1 unit and invert the rate.
        inverse_rate = (api.currency_conversion(amount=1, symbol=f"{step_end}/{step_start}") or {}).get('rate')
        if not inverse_rate:  # missing or zero -> cannot be inverted
            return {"error": f"Error in conversion step from {step_start} to {step_end}."}
        rate = 1 / inverse_rate
        current_amount *= rate
        steps.append({"step": f"{step_no}. {step_start} → {step_end} (Inverse)", "rate": rate, "intermediate_amount": current_amount})
    return {"status": "Success", "original_amount": amount, "final_amount": current_amount, "path_taken": path, "conversion_steps": steps}
def perform_currency_conversion(amount: float, symbol: str):
    """Run a bridged conversion and mirror its inputs and result into the converter tab's state.

    Updates ``st.session_state.currency_converter_state`` with the result and the amount (as
    float). When *symbol* parses as ``FROM/TO``, it also stores the upper-cased codes so the tab's
    selectboxes can pre-select them. The active tab is switched to 'Currency Converter'.

    Returns:
        The dict produced by ``convert_currency_with_bridge``.
    """
    result = convert_currency_with_bridge(amount, symbol)
    converter_state = st.session_state.currency_converter_state
    # Store a float: st.number_input rejects a default whose type differs from min_value (0.0).
    converter_state.update({'result': result, 'amount': float(amount)})
    from_curr, separator, to_curr = str(symbol).partition('/')
    from_curr, to_curr = from_curr.strip().upper(), to_curr.strip().upper()
    # Only remember well-formed pairs instead of silently swallowing every error.
    if separator and from_curr and to_curr and '/' not in to_curr:
        converter_state.update({'from': from_curr, 'to': to_curr})
    st.session_state.active_tab = 'Currency Converter'
    return result
# --- 4. GEMINI CONFIGURATION ---
# System prompt passed as `system_instruction` to the Gemini model in get_model_and_tools().
# It tells the model when to call find_and_process_stock / get_smart_time_series /
# perform_currency_conversion and to keep replies short, leaving data display to the dashboard tabs.
# NOTE: this text is sent to the model verbatim; any edit changes the assistant's behaviour.
SYSTEM_INSTRUCTION = """You are the AI brain controlling an Interactive Financial Dashboard. Your task is to understand user requests, call appropriate tools, and report results concisely.
GOLDEN RULES:
1. **UNDERSTAND FIRST, CALL LATER:**
* **Company Name:** When a user enters a company name (e.g., "Vingroup Corporation", "Apple"), your FIRST task is to use the `find_and_process_stock` tool to identify the official stock symbol.
* **Stock Symbol:** When a user directly provides a stock symbol (e.g., "AAPL", "VNM"), use `find_and_process_stock` first to confirm and add it to the watchlist.
* **Time Period Request:** When user asks for specific time period (e.g., "last year", "last month"), first make sure the stock symbol is processed with `find_and_process_stock`, then use `get_smart_time_series` with appropriate time_period.
* **Country Name:** When a user enters a country name for currency (e.g., "Vietnamese currency"), you must infer the 3-letter currency code ("VND") BEFORE calling the `perform_currency_conversion` tool.
2. **ACT AND NOTIFY:** Your role is to execute commands and report briefly.
* **Found 1 symbol:** "I've found [Company Name] ([Symbol]) and automatically added it to your watchlist and chart."
* **Found multiple symbols:** "I found several results for '[query]'. Please specify which exact symbol you want to track?"
* **Currency conversion:** "Done. Please see detailed results in the 'Currency Converter' tab."
3. **NO DATA LISTING:** The dashboard already displays everything. ABSOLUTELY do not repeat lists, numbers, or raw data in your response.
"""
@st.cache_resource
def get_model_and_tools():
    """Create the Gemini model with the dashboard's function-calling tool attached.

    Wrapped in ``st.cache_resource`` so the model object is built once and reused across reruns.
    The single ``glm.Tool`` declares three functions whose names match the keys of
    AVAILABLE_FUNCTIONS: find_and_process_stock, get_smart_time_series and
    perform_currency_conversion.

    Returns:
        genai.GenerativeModel configured with MODEL_NAME and SYSTEM_INSTRUCTION.
    """
    def text(**extra):
        # STRING-typed schema, optionally with a description or an enum.
        return glm.Schema(type=glm.Type.STRING, **extra)

    def arguments(properties, required):
        # OBJECT schema describing a function's named arguments.
        return glm.Schema(type=glm.Type.OBJECT, properties=properties, required=required)

    declarations = [
        glm.FunctionDeclaration(
            name="find_and_process_stock",
            description="Search for stock by symbol or name and automatically process. Use this tool FIRST to identify the official stock symbol.",
            parameters=arguments(
                {'query': text(description="Symbol or company name, e.g., 'Vingroup', 'Apple'.")},
                ['query'],
            ),
        ),
        glm.FunctionDeclaration(
            name="get_smart_time_series",
            description="Get price history data after knowing the official stock symbol.",
            parameters=arguments(
                {
                    'symbol': text(),
                    'time_period': text(enum=["intraday", "1_week", "1_month", "6_months", "1_year"]),
                },
                ['symbol', 'time_period'],
            ),
        ),
        glm.FunctionDeclaration(
            name="perform_currency_conversion",
            description="Convert currency after knowing the 3-letter code of source/target currency pair, e.g., USD/VND, JPY/EUR",
            parameters=arguments(
                {'amount': glm.Schema(type=glm.Type.NUMBER), 'symbol': text()},
                ['amount', 'symbol'],
            ),
        ),
    ]
    return genai.GenerativeModel(
        model_name=MODEL_NAME,
        tools=[glm.Tool(function_declarations=declarations)],
        system_instruction=SYSTEM_INSTRUCTION,
    )
model = get_model_and_tools()
# Create the Gemini chat once per browser session; session_state keeps it across reruns.
if st.session_state.chat_session is None:
    st.session_state.chat_session = model.start_chat(history=[])
# Dispatch table: tool name as declared to Gemini -> local implementation.
AVAILABLE_FUNCTIONS = {
    tool.__name__: tool
    for tool in (find_and_process_stock, get_smart_time_series, perform_currency_conversion)
}
# --- 5. TAB DISPLAY LOGIC ---
def get_y_axis_domain(series: pd.Series, padding_percent: float = 0.1):
    """Return a padded ``[low, high]`` y-axis domain for *series*, or ``None`` without data.

    Normally each side is padded by ``padding_percent`` of the data range. For a flat series
    (range 0) the padding is ``padding_percent / 2`` of the value's magnitude. If that padding
    would be 0 (e.g. a flat series at exactly 0), a padding of 1.0 is used, so the domain never
    collapses to a single point.

    Args:
        series: Numeric values; NaNs are ignored by ``min``/``max``.
        padding_percent: Fraction of the range added above and below the data.

    Returns:
        ``[low - pad, high + pad]``, or ``None`` when *series* is empty or entirely NaN.
    """
    if series.empty:
        return None
    low, high = series.min(), series.max()
    if pd.isna(low) or pd.isna(high):
        return None
    spread = high - low
    if spread == 0:
        pad = abs(high * (padding_percent / 2))
        if pad == 0:
            # A zero-width domain gives the chart no vertical extent; fall back to a unit pad.
            pad = 1.0
    else:
        pad = spread * padding_percent
    return [low - pad, high + pad]
def render_watchlist_tab():
    """Render one row per watched stock (symbol/name, exchange/country/currency, delete button).

    Shows a hint when the watchlist is empty. Deleting a stock also drops its cached time series
    and triggers a rerun.
    """
    st.subheader("Watchlist")
    watchlist = st.session_state.stock_watchlist
    if not watchlist:
        st.info("No stocks yet. Try searching for a symbol like 'Apple' or 'VNM'.")
        return
    # Iterate over a snapshot because the delete button mutates the watchlist.
    for ticker, details in list(watchlist.items()):
        name_col, listing_col, action_col = st.columns([4, 4, 1])
        with name_col:
            st.markdown(f"**{ticker}**")
            st.caption(details.get('name', 'N/A'))
        with listing_col:
            st.markdown(f"**{details.get('exchange', 'N/A')}**")
            st.caption(f"{details.get('country', 'N/A')} - {details.get('currency', 'N/A')}")
        with action_col:
            delete_clicked = st.button("🗑️", key=f"delete_{ticker}", help=f"Delete {ticker}")
        if delete_clicked:
            watchlist.pop(ticker, None)
            st.session_state.timeseries_cache.pop(ticker, None)
            st.rerun()
        st.divider()
def _render_growth_chart(all_series_data):
    """Plot every series rebased to 100 at its first valid close, on one shared Altair chart."""
    st.markdown("##### Growth Performance Comparison (%)")
    normalized_dfs = []
    for symbol, df in all_series_data.items():
        if df.empty:
            continue
        closes = df['close'].dropna()
        # Rebasing needs a valid, non-zero first close; otherwise the series becomes NaN/inf.
        if closes.empty or closes.iloc[0] == 0:
            continue
        normalized_df = ((closes / closes.iloc[0]) * 100).reset_index()
        normalized_df.columns = ['datetime', 'value']
        normalized_df['symbol'] = symbol
        normalized_dfs.append(normalized_df)
    if not normalized_dfs:
        st.warning("No data to draw growth chart.")
        return
    full_normalized_df = pd.concat(normalized_dfs, ignore_index=True)
    y_domain = get_y_axis_domain(full_normalized_df['value'])
    chart = alt.Chart(full_normalized_df).mark_line().encode(
        x=alt.X('datetime:T', title='Time'),
        y=alt.Y('value:Q', scale=alt.Scale(domain=y_domain, zero=False), title='Growth (%)'),
        color=alt.Color('symbol:N', title='Symbol'),
        tooltip=[
            alt.Tooltip('symbol:N', title='Symbol'),
            alt.Tooltip('datetime:T', title='Time', format='%Y-%m-%d %H:%M'),
            alt.Tooltip('value:Q', title='Growth', format='.2f'),
        ],
    ).interactive()
    st.altair_chart(chart, use_container_width=True)


def _render_price_charts(all_series_data, watchlist):
    """Plot one close-price line chart per symbol, labelled with its listing currency."""
    st.markdown("##### Actual Price Charts")
    for symbol, df in all_series_data.items():
        stock_info = watchlist.get(symbol, {})
        st.markdown(f"**{symbol}** ({stock_info.get('currency', 'N/A')})")
        if df.empty:
            continue
        y_domain = get_y_axis_domain(df['close'])
        price_chart = alt.Chart(df.reset_index()).mark_line().encode(
            x=alt.X('datetime:T', title='Time'),
            y=alt.Y('close:Q', scale=alt.Scale(domain=y_domain, zero=False), title='Price'),
            tooltip=[
                alt.Tooltip('datetime:T', title='Time', format='%Y-%m-%d %H:%M'),
                alt.Tooltip('close:Q', title='Price', format=',.2f'),
            ],
        ).interactive()
        st.altair_chart(price_chart, use_container_width=True)


def render_timeseries_tab():
    """Render the period selector, the growth-comparison chart and one price chart per stock.

    Charts read only from ``st.session_state.timeseries_cache``. When the user picks a new
    period, every watched symbol is refetched through ``get_smart_time_series`` (which stores
    the parsed DataFrame in that cache) and the script reruns.
    """
    st.subheader("Chart Analysis")
    state = st.session_state
    if not state.stock_watchlist:
        st.info("Please add at least one stock to the watchlist to view charts.")
        return
    # Radio label -> period key accepted by get_smart_time_series.
    time_periods = {'Intraday': 'intraday', '1 Week': '1_week', '1 Month': '1_month', '6 Months': '6_months', '1 Year': '1_year'}
    period_labels = list(time_periods)
    period_values = list(time_periods.values())
    current_period = state.active_timeseries_period
    default_index = period_values.index(current_period) if current_period in period_values else 0
    selected_label = st.radio("Select time period:", options=period_labels, horizontal=True, index=default_index)
    selected_period = time_periods[selected_label]
    if current_period != selected_period:
        state.active_timeseries_period = selected_period
        with st.spinner("Updating charts..."):
            for symbol in list(state.stock_watchlist):
                # get_smart_time_series caches the parsed DataFrame itself; no re-parse needed.
                get_smart_time_series(symbol, selected_period)
        st.rerun()
    cache = state.timeseries_cache
    all_series_data = {
        symbol: cache[symbol][selected_period]
        for symbol in state.stock_watchlist
        if selected_period in cache.get(symbol, {})
    }
    if not all_series_data:
        st.warning("Not enough data for the selected time period.")
        return
    _render_growth_chart(all_series_data)
    st.divider()
    _render_price_charts(all_series_data, state.stock_watchlist)
def _currency_index(code, fallback):
    """Position of *code* in AVAILABLE_CURRENCIES, else *fallback* clamped to the list length."""
    if code in AVAILABLE_CURRENCIES:
        return AVAILABLE_CURRENCIES.index(code)
    # Clamp so a short currency list (e.g. a single entry) cannot produce an out-of-range index.
    return min(fallback, len(AVAILABLE_CURRENCIES) - 1)


def render_currency_tab():
    """Render the manual currency-converter form and the most recent conversion result.

    Widget defaults come from ``st.session_state.currency_converter_state``, which
    ``perform_currency_conversion`` also updates when the AI runs a conversion.
    """
    st.subheader("Currency Converter Tool")
    state = st.session_state.currency_converter_state
    if not AVAILABLE_CURRENCIES:
        # Without forex pairs there is nothing to select and no conversion path can exist.
        st.warning("No currencies available: forex pairs could not be loaded.")
    else:
        col1, col2 = st.columns(2)
        # number_input rejects a default below min_value or of a different numeric type.
        start_amount = max(0.0, float(state['amount']))
        amount = col1.number_input("Amount", value=start_amount, min_value=0.0, format="%.2f", key="conv_amount")
        from_curr = col1.selectbox("From", options=AVAILABLE_CURRENCIES, index=_currency_index(state['from'], 0), key="conv_from")
        to_curr = col2.selectbox("To", options=AVAILABLE_CURRENCIES, index=_currency_index(state['to'], 1), key="conv_to")
        if st.button("Convert", use_container_width=True, key="conv_btn"):
            with st.spinner("Converting..."):
                perform_currency_conversion(amount, f"{from_curr}/{to_curr}")
            st.rerun()
    res = state['result']
    if res:
        if res.get('status') == 'Success':
            path = res['path_taken']
            st.success(f"**Result:** `{res['original_amount']:,.2f} {path[0]}` = `{res['final_amount']:,.2f} {path[-1]}`")
            # Bridged conversions: show the route so the intermediate currencies are visible.
            if len(path) > 2:
                st.caption("Route: " + " → ".join(path))
        else:
            st.error(f"Error: {res.get('error', 'Unknown')}")
# --- 6. MAIN APP LAYOUT & CONTROL FLOW ---
st.title("📈 AI Financial Dashboard")

# Upper bound on model -> tools -> model round trips for one user message, so a model that
# keeps requesting tools cannot loop forever.
MAX_TOOL_ROUNDS = 10
# Literal phrases in the user's prompt that switch the chart period, checked in this order.
TIME_PERIOD_KEYWORDS = (
    ("last year", "1_year"),
    ("last 6 months", "6_months"),
    ("last month", "1_month"),
    ("last week", "1_week"),
)


def _function_calls(response):
    """Return the function calls in the first candidate of a Gemini *response* (``[]`` if none)."""
    if not response.candidates:
        return []
    return [part.function_call for part in response.candidates[0].content.parts if part.function_call]


def _execute_tool_call(call):
    """Run one Gemini function call locally and wrap the outcome as a FunctionResponse part.

    Unknown tool names and exceptions raised by a tool are reported back to the model as an
    ``{'error': ...}`` payload instead of aborting the script run.
    """
    func_name = call.name
    func = AVAILABLE_FUNCTIONS.get(func_name)
    if func is None:
        payload = {'error': f"Function '{func_name}' not found."}
    else:
        try:
            payload = {'result': func(**{k: v for k, v in call.args.items()})}
        except Exception as exc:  # tool boundary: let the model explain the failure to the user
            print(f"Tool '{func_name}' failed: {exc!r}")
            payload = {'error': f"{func_name} failed: {exc}"}
    return glm.Part(function_response=glm.FunctionResponse(name=func_name, response=payload))


def _response_text(response):
    """Return the text of a Gemini response; ``response.text`` raises ValueError without a text part."""
    try:
        return response.text
    except ValueError:
        return "I couldn't produce a text reply for that request."


def _run_agent_turn(prompt):
    """Send *prompt* to the chat session, execute requested tools until the model replies in text, return that text."""
    chat = st.session_state.chat_session
    response = chat.send_message(prompt)
    for _ in range(MAX_TOOL_ROUNDS):
        calls = _function_calls(response)
        if not calls:
            break
        tool_parts = [_execute_tool_call(call) for call in calls]
        response = chat.send_message(glm.Content(parts=tool_parts))
    return _response_text(response)


def _apply_time_period_keywords(prompt):
    """Switch the chart period when *prompt* mentions one of TIME_PERIOD_KEYWORDS.

    On a change, every watched symbol missing data for the new period is fetched and the
    active tab is set to 'Time Charts'.
    """
    lowered = (prompt or "").lower()
    new_period = next((period for phrase, period in TIME_PERIOD_KEYWORDS if phrase in lowered), None)
    state = st.session_state
    if new_period is None or new_period == state.active_timeseries_period:
        return
    state.active_timeseries_period = new_period
    if not state.stock_watchlist:
        return
    for symbol in list(state.stock_watchlist):
        if new_period not in state.timeseries_cache.get(symbol, {}):
            get_smart_time_series(symbol, new_period)
    state.active_tab = 'Time Charts'


# Split the page into two columns: chat on the left, dashboard tabs on the right.
col1, col2 = st.columns([1, 1])
# Left column: conversation with the AI.
with col1:
    chat_container = st.container(height=600)
    with chat_container:
        for message in st.session_state.chat_history:
            with st.chat_message(message["role"]):
                st.markdown(message["parts"])
# Right column: tabs with charts and data.
with col2:
    right_column_container = st.container(height=600)
    with right_column_container:
        tab_names = ['Stock Watchlist', 'Time Charts', 'Currency Converter']
        # Normalise an unknown active_tab value back to the first tab name.
        try:
            default_index = tab_names.index(st.session_state.active_tab)
        except ValueError:
            default_index = 0
        st.session_state.active_tab = tab_names[default_index]
        tab1, tab2, tab3 = st.tabs(tab_names)
        with tab1:
            render_watchlist_tab()
        with tab2:
            render_timeseries_tab()
        with tab3:
            render_currency_tab()

# Chat input at the bottom of the page: record the prompt, then rerun so it is displayed first.
user_prompt = st.chat_input("Ask AI to control the dashboard...")
if user_prompt:
    st.session_state.chat_history.append({"role": "user", "parts": user_prompt})
    st.rerun()

# Answer a pending user message and show the AI's reply.
if st.session_state.chat_history and st.session_state.chat_history[-1]["role"] == "user":
    last_user_prompt = st.session_state.chat_history[-1]["parts"]
    with chat_container:
        with st.chat_message("model"):
            with st.spinner("🤖 AI executing command..."):
                try:
                    reply = _run_agent_turn(last_user_prompt)
                    _apply_time_period_keywords(last_user_prompt)
                except Exception as exc:  # top-level boundary: surface the failure in the chat
                    print(f"AI turn failed: {exc!r}")
                    reply = f"Sorry, I couldn't complete that request ({exc})."
    # Always append a model reply: a trailing user message would be re-sent on every rerun.
    st.session_state.chat_history.append({"role": "model", "parts": reply})
    st.rerun()