Spaces:
Sleeping
Sleeping
# app.py (Phiên bản cuối cùng với Biểu đồ Nâng cao) | |
import streamlit as st | |
import pandas as pd | |
import altair as alt # <-- Thêm thư viện Altair | |
import google.generativeai as genai | |
import google.ai.generativelanguage as glm | |
from dotenv import load_dotenv | |
import os | |
from twelvedata_api import TwelveDataAPI | |
from collections import deque | |
from datetime import datetime | |
# --- 1. CẤU HÌNH BAN ĐẦU & KHỞI TẠO STATE --- | |
load_dotenv() | |
st.set_page_config(layout="wide", page_title="AI Financial Dashboard") | |
genai.configure(api_key=os.getenv("GEMINI_API_KEY")) | |
def initialize_state(): | |
if "initialized" in st.session_state: return | |
st.session_state.initialized = True | |
st.session_state.td_api = TwelveDataAPI(os.getenv("TWELVEDATA_API_KEY")) | |
st.session_state.stock_watchlist = {} | |
st.session_state.timeseries_cache = {} | |
st.session_state.active_timeseries_period = 'intraday' | |
st.session_state.currency_converter_state = {'from': 'USD', 'to': 'VND', 'amount': 100.0, 'result': None} | |
st.session_state.chat_history = [] | |
st.session_state.active_tab = 'Danh sách mã chứng khoán' | |
st.session_state.chat_session = None | |
initialize_state() | |
# --- 2. TẢI DỮ LIỆU NỀN --- | |
def load_market_data(): | |
td_api = st.session_state.td_api | |
stocks_data = td_api.get_all_stocks() | |
forex_data = td_api.get_forex_pairs() | |
forex_graph = {} | |
if forex_data and 'data' in forex_data: | |
for pair in forex_data['data']: | |
base, quote = pair['symbol'].split('/'); forex_graph.setdefault(base, []); forex_graph.setdefault(quote, []); forex_graph[base].append(quote); forex_graph[quote].append(base) | |
country_currency_map = {} | |
if stocks_data and 'data' in stocks_data: | |
for stock in stocks_data['data']: | |
country, currency = stock.get('country'), stock.get('currency') | |
if country and currency: country_currency_map[country.lower()] = currency | |
all_currencies = sorted(forex_graph.keys()) | |
return stocks_data, forex_graph, country_currency_map, all_currencies | |
ALL_STOCKS_CACHE, FOREX_GRAPH, COUNTRY_CURRENCY_MAP, AVAILABLE_CURRENCIES = load_market_data() | |
# --- 3. LOGIC THỰC THI TOOL --- | |
def find_and_process_stock(query: str): | |
print(f"Hybrid searching for stock: '{query}'...") | |
query_lower = query.lower() | |
found_data = [s for s in ALL_STOCKS_CACHE.get('data', []) if query_lower in s['symbol'].lower() or query_lower in s['name'].lower()] | |
if not found_data: | |
results = st.session_state.td_api.get_stocks(symbol=query) | |
found_data = results.get('data', []) | |
if len(found_data) == 1: | |
stock_info = found_data[0]; symbol = stock_info['symbol'] | |
st.session_state.stock_watchlist[symbol] = stock_info | |
ts_data = get_smart_time_series(symbol=symbol, time_period='intraday') | |
if 'values' in ts_data: | |
df = pd.DataFrame(ts_data['values']); df['datetime'] = pd.to_datetime(df['datetime']); df['close'] = pd.to_numeric(df['close']) | |
if symbol not in st.session_state.timeseries_cache: st.session_state.timeseries_cache[symbol] = {} | |
st.session_state.timeseries_cache[symbol]['intraday'] = df.sort_values('datetime').set_index('datetime') | |
st.session_state.active_tab = 'Biểu đồ thời gian'; st.session_state.active_timeseries_period = 'intraday' | |
return {"status": "SINGLE_STOCK_PROCESSED", "symbol": symbol, "name": stock_info.get('name', 'N/A')} | |
elif len(found_data) > 1: return {"status": "MULTIPLE_STOCKS_FOUND", "data": found_data[:5]} | |
else: return {"status": "NO_STOCKS_FOUND"} | |
def get_smart_time_series(symbol: str, time_period: str): | |
logic_map = {'intraday': {'interval': '15min', 'outputsize': 120}, '1_week': {'interval': '1h', 'outputsize': 40}, '1_month': {'interval': '1day', 'outputsize': 22}, '6_months': {'interval': '1day', 'outputsize': 120}, '1_year': {'interval': '1week', 'outputsize': 52}} | |
params = logic_map.get(time_period) | |
if not params: return {"error": f"Khoảng thời gian '{time_period}' không hợp lệ."} | |
return st.session_state.td_api.get_time_series(symbol=symbol, **params) | |
def find_conversion_path_bfs(start, end): | |
if start not in FOREX_GRAPH or end not in FOREX_GRAPH: return None | |
q = deque([(start, [start])]); visited = {start} | |
while q: | |
curr, path = q.popleft() | |
if curr == end: return path | |
for neighbor in FOREX_GRAPH.get(curr, []): | |
if neighbor not in visited: visited.add(neighbor); q.append((neighbor, path + [neighbor])) | |
return None | |
def convert_currency_with_bridge(amount: float, symbol: str): | |
try: start_currency, end_currency = symbol.upper().split('/') | |
except ValueError: return {"error": "Định dạng cặp tiền tệ không hợp lệ."} | |
path = find_conversion_path_bfs(start_currency, end_currency) | |
if not path: return {"error": f"Không tìm thấy đường đi quy đổi từ {start_currency} sang {end_currency}."} | |
current_amount = amount; steps = [] | |
for i in range(len(path) - 1): | |
step_start, step_end = path[i], path[i+1] | |
result = st.session_state.td_api.currency_conversion(amount=current_amount, symbol=f"{step_start}/{step_end}") | |
if 'rate' in result and result.get('rate') is not None: | |
current_amount = result['amount']; steps.append({"step": f"{i+1}. {step_start} → {step_end}", "rate": result['rate'], "intermediate_amount": current_amount}) | |
else: | |
inverse_result = st.session_state.td_api.currency_conversion(amount=1, symbol=f"{step_end}/{step_start}") | |
if 'rate' in inverse_result and inverse_result.get('rate') and inverse_result['rate'] != 0: | |
rate = 1 / inverse_result['rate']; current_amount *= rate; steps.append({"step": f"{i+1}. {step_start} → {step_end} (Inverse)", "rate": rate, "intermediate_amount": current_amount}) | |
else: return {"error": f"Lỗi ở bước quy đổi từ {step_start} sang {step_end}."} | |
return {"status": "Success", "original_amount": amount, "final_amount": current_amount, "path_taken": path, "conversion_steps": steps} | |
def perform_currency_conversion(amount: float, symbol: str): | |
result = convert_currency_with_bridge(amount, symbol) | |
st.session_state.currency_converter_state.update({'result': result, 'amount': amount}) | |
try: | |
from_curr, to_curr = symbol.split('/'); st.session_state.currency_converter_state.update({'from': from_curr, 'to': to_curr}) | |
except: pass | |
st.session_state.active_tab = 'Quy đổi tiền tệ' | |
return result | |
# --- 4. CẤU HÌNH GEMINI --- | |
SYSTEM_INSTRUCTION = """Bạn là bộ não AI điều khiển một Bảng điều khiển Tài chính Tương tác. Nhiệm vụ của bạn là hiểu yêu cầu của người dùng, gọi các công cụ phù hợp, và thông báo kết quả một cách súc tích. | |
QUY TẮC VÀNG: | |
1. **HIỂU TRƯỚC, GỌI SAU:** | |
* **Tên công ty:** Khi người dùng nhập một tên công ty (ví dụ: "Tập đoàn Vingroup", "Apple"), nhiệm vụ ĐẦU TIÊN của bạn là dùng tool `find_and_process_stock` để xác định mã chứng khoán chính thức. | |
* **Tên quốc gia:** Khi người dùng nhập tên quốc gia cho tiền tệ (ví dụ: "tiền Việt Nam"), bạn phải tự suy luận ra mã tiền tệ 3 chữ cái ("VND") TRƯỚC KHI gọi tool `perform_currency_conversion`. | |
2. **HÀNH ĐỘNG VÀ THÔNG BÁO:** Vai trò của bạn là thực thi lệnh và thông báo ngắn gọn. | |
* **Tìm thấy 1 mã:** "Tôi đã tìm thấy [Tên công ty] ([Mã CK]) và đã tự động thêm vào danh sách theo dõi và biểu đồ của bạn." | |
* **Tìm thấy nhiều mã:** "Tôi tìm thấy một vài kết quả cho '[query]'. Bạn vui lòng cho biết mã chính xác bạn muốn theo dõi?" | |
* **Quy đổi tiền tệ:** "Đã thực hiện. Mời bạn xem kết quả chi tiết trong tab 'Quy đổi tiền tệ'." | |
3. **CẤM LIỆT KÊ DỮ LIỆU:** Bảng điều khiển đã hiển thị tất cả. TUYỆT ĐỐI không lặp lại danh sách, các con số, hay dữ liệu thô trong câu trả lời của bạn. | |
""" | |
def get_model_and_tools(): | |
find_stock_func = glm.FunctionDeclaration(name="find_and_process_stock", description="Tìm kiếm cổ phiếu theo mã hoặc tên và tự động xử lý. Dùng tool này ĐẦU TIÊN để xác định mã CK chính thức.", parameters=glm.Schema(type=glm.Type.OBJECT, properties={'query': glm.Schema(type=glm.Type.STRING, description="Mã hoặc tên công ty, ví dụ: 'Vingroup', 'Apple'.")}, required=['query'])) | |
get_ts_func = glm.FunctionDeclaration(name="get_smart_time_series", description="Lấy dữ liệu lịch sử giá sau khi đã biết mã CK chính thức.", parameters=glm.Schema(type=glm.Type.OBJECT, properties={'symbol': glm.Schema(type=glm.Type.STRING), 'time_period': glm.Schema(type=glm.Type.STRING, enum=["intraday", "1_week", "1_month", "6_months", "1_year"])}, required=['symbol', 'time_period'])) | |
currency_func = glm.FunctionDeclaration(name="perform_currency_conversion", description="Quy đổi tiền tệ sau khi đã biết mã 3 chữ cái của cặp tiền tệ nguồn/đích, ví dụ USD/VND, JPY/EUR", parameters=glm.Schema(type=glm.Type.OBJECT, properties={'amount': glm.Schema(type=glm.Type.NUMBER), 'symbol': glm.Schema(type=glm.Type.STRING)}, required=['amount', 'symbol'])) | |
finance_tool = glm.Tool(function_declarations=[find_stock_func, get_ts_func, currency_func]) | |
model = genai.GenerativeModel(model_name="gemini-1.5-pro-latest", tools=[finance_tool], system_instruction=SYSTEM_INSTRUCTION) | |
return model | |
model = get_model_and_tools() | |
if st.session_state.chat_session is None: | |
st.session_state.chat_session = model.start_chat(history=[]) | |
AVAILABLE_FUNCTIONS = {"find_and_process_stock": find_and_process_stock, "get_smart_time_series": get_smart_time_series, "perform_currency_conversion": perform_currency_conversion} | |
# --- 5. LOGIC HIỂN THỊ CÁC TAB --- | |
def get_y_axis_domain(series: pd.Series, padding_percent: float = 0.1): | |
if series.empty: return None | |
data_min, data_max = series.min(), series.max() | |
if pd.isna(data_min) or pd.isna(data_max): return None | |
data_range = data_max - data_min | |
if data_range == 0: | |
padding = abs(data_max * (padding_percent / 2)) | |
return [data_min - padding, data_max + padding] | |
padding = data_range * padding_percent | |
return [data_min - padding, data_max + padding] | |
def render_watchlist_tab(): | |
st.subheader("Danh sách theo dõi") | |
if not st.session_state.stock_watchlist: st.info("Chưa có cổ phiếu nào. Hãy thử tìm kiếm một mã như 'Apple' hoặc 'VNM'."); return | |
for symbol, stock_info in list(st.session_state.stock_watchlist.items()): | |
col1, col2, col3 = st.columns([4, 4, 1]) | |
with col1: st.markdown(f"**{symbol}**"); st.caption(stock_info.get('name', 'N/A')) | |
with col2: st.markdown(f"**{stock_info.get('exchange', 'N/A')}**"); st.caption(f"{stock_info.get('country', 'N/A')} - {stock_info.get('currency', 'N/A')}") | |
with col3: | |
if st.button("🗑️", key=f"delete_{symbol}", help=f"Xóa {symbol}"): | |
st.session_state.stock_watchlist.pop(symbol, None); st.session_state.timeseries_cache.pop(symbol, None); st.rerun() | |
st.divider() | |
def render_timeseries_tab(): | |
st.subheader("Phân tích Biểu đồ") | |
if not st.session_state.stock_watchlist: | |
st.info("Hãy thêm ít nhất một cổ phiếu vào danh sách để xem biểu đồ."); return | |
time_periods = {'Trong ngày': 'intraday', '1 Tuần': '1_week', '1 Tháng': '1_month', '6 Tháng': '6_months', '1 Năm': '1_year'} | |
period_keys = list(time_periods.keys()) | |
period_values = list(time_periods.values()) | |
default_index = period_values.index(st.session_state.active_timeseries_period) if st.session_state.active_timeseries_period in period_values else 0 | |
selected_label = st.radio("Chọn khoảng thời gian:", options=period_keys, horizontal=True, index=default_index) | |
selected_period = time_periods[selected_label] | |
if st.session_state.active_timeseries_period != selected_period: | |
st.session_state.active_timeseries_period = selected_period | |
with st.spinner(f"Đang cập nhật biểu đồ..."): | |
for symbol in st.session_state.stock_watchlist.keys(): | |
ts_data = get_smart_time_series(symbol, selected_period) | |
if 'values' in ts_data: | |
df = pd.DataFrame(ts_data['values']); df['datetime'] = pd.to_datetime(df['datetime']); df['close'] = pd.to_numeric(df['close']) | |
if symbol not in st.session_state.timeseries_cache: st.session_state.timeseries_cache[symbol] = {} | |
st.session_state.timeseries_cache[symbol][selected_period] = df.sort_values('datetime').set_index('datetime') | |
st.rerun() | |
all_series_data = {symbol: st.session_state.timeseries_cache[symbol][selected_period] for symbol in st.session_state.stock_watchlist.keys() if symbol in st.session_state.timeseries_cache and selected_period in st.session_state.timeseries_cache[symbol]} | |
if not all_series_data: | |
st.warning("Không có đủ dữ liệu cho khoảng thời gian đã chọn."); return | |
st.markdown("##### So sánh Hiệu suất Tăng trưởng (%)") | |
normalized_dfs = [] | |
for symbol, df in all_series_data.items(): | |
if not df.empty: | |
normalized_series = (df['close'] / df['close'].iloc[0]) * 100 | |
normalized_df = normalized_series.reset_index(); normalized_df.columns = ['datetime', 'value']; normalized_df['symbol'] = symbol | |
normalized_dfs.append(normalized_df) | |
if normalized_dfs: | |
full_normalized_df = pd.concat(normalized_dfs) | |
y_domain = get_y_axis_domain(full_normalized_df['value']) | |
chart = alt.Chart(full_normalized_df).mark_line().encode(x=alt.X('datetime:T', title='Thời gian'), y=alt.Y('value:Q', scale=alt.Scale(domain=y_domain, zero=False), title='Tăng trưởng (%)'), color=alt.Color('symbol:N', title='Mã CK'), tooltip=[alt.Tooltip('symbol:N', title='Mã'), alt.Tooltip('datetime:T', title='Thời điểm', format='%Y-%m-%d %H:%M'), alt.Tooltip('value:Q', title='Tăng trưởng', format='.2f')]).interactive() | |
st.altair_chart(chart, use_container_width=True) | |
else: | |
st.warning("Không có dữ liệu để vẽ biểu đồ tăng trưởng.") | |
st.divider() | |
st.markdown("##### Biểu đồ Giá Thực tế") | |
for symbol, df in all_series_data.items(): | |
stock_info = st.session_state.stock_watchlist.get(symbol, {}) | |
st.markdown(f"**{symbol}** ({stock_info.get('currency', 'N/A')})") | |
if not df.empty: | |
y_domain = get_y_axis_domain(df['close']) | |
data_for_chart = df.reset_index() | |
price_chart = alt.Chart(data_for_chart).mark_line().encode(x=alt.X('datetime:T', title='Thời gian'), y=alt.Y('close:Q', scale=alt.Scale(domain=y_domain, zero=False), title='Giá'), tooltip=[alt.Tooltip('datetime:T', title='Thời điểm', format='%Y-%m-%d %H:%M'), alt.Tooltip('close:Q', title='Giá', format=',.2f')]).interactive() | |
st.altair_chart(price_chart, use_container_width=True) | |
def render_currency_tab(): | |
st.subheader("Công cụ quy đổi tiền tệ"); state = st.session_state.currency_converter_state | |
col1, col2 = st.columns(2) | |
amount = col1.number_input("Số tiền", value=state['amount'], min_value=0.0, format="%.2f", key="conv_amount") | |
from_curr = col1.selectbox("Từ", options=AVAILABLE_CURRENCIES, index=AVAILABLE_CURRENCIES.index(state['from']) if state['from'] in AVAILABLE_CURRENCIES else 0, key="conv_from") | |
to_curr = col2.selectbox("Sang", options=AVAILABLE_CURRENCIES, index=AVAILABLE_CURRENCIES.index(state['to']) if state['to'] in AVAILABLE_CURRENCIES else 1, key="conv_to") | |
if st.button("Quy đổi", use_container_width=True, key="conv_btn"): | |
with st.spinner("Đang quy đổi..."): result = perform_currency_conversion(amount, f"{from_curr}/{to_curr}"); st.rerun() | |
if state['result']: | |
res = state['result'] | |
if res.get('status') == 'Success': st.success(f"**Kết quả:** `{res['original_amount']:,.2f} {res['path_taken'][0]}` = `{res['final_amount']:,.2f} {res['path_taken'][-1]}`") | |
else: st.error(f"Lỗi: {res.get('error', 'Không rõ')}") | |
# --- 6. MAIN APP LAYOUT & CONTROL FLOW --- | |
st.title("📈 AI Financial Dashboard") | |
col1, col2 = st.columns([1, 1]) | |
with col2: | |
right_column_container = st.container(height=600) | |
with right_column_container: | |
tab_names = ['Danh sách mã chứng khoán', 'Biểu đồ thời gian', 'Quy đổi tiền tệ'] | |
try: default_index = tab_names.index(st.session_state.active_tab) | |
except ValueError: default_index = 0 | |
st.session_state.active_tab = tab_names[default_index] | |
tab1, tab2, tab3 = st.tabs(tab_names) | |
with tab1: render_watchlist_tab() | |
with tab2: render_timeseries_tab() | |
with tab3: render_currency_tab() | |
with col1: | |
chat_container = st.container(height=600) | |
with chat_container: | |
for message in st.session_state.chat_history: | |
with st.chat_message(message["role"]): | |
st.markdown(message["parts"]) | |
user_prompt = st.chat_input("Hỏi AI để điều khiển bảng điều khiển...") | |
if user_prompt: | |
st.session_state.chat_history.append({"role": "user", "parts": user_prompt}) | |
st.rerun() | |
if st.session_state.chat_history and st.session_state.chat_history[-1]["role"] == "user": | |
last_user_prompt = st.session_state.chat_history[-1]["parts"] | |
# ***** ĐÂY LÀ PHẦN THAY ĐỔI ***** | |
with chat_container: | |
with st.chat_message("model"): | |
with st.spinner("🤖 AI đang thực thi lệnh..."): | |
response = st.session_state.chat_session.send_message(last_user_prompt) | |
tool_calls = [part.function_call for part in response.candidates[0].content.parts if part.function_call] | |
while tool_calls: | |
tool_responses = [] | |
for call in tool_calls: | |
func_name = call.name; func_args = {k: v for k, v in call.args.items()} | |
if func_name in AVAILABLE_FUNCTIONS: | |
tool_result = AVAILABLE_FUNCTIONS[func_name](**func_args) | |
tool_responses.append(glm.Part(function_response=glm.FunctionResponse(name=func_name, response={'result': tool_result}))) | |
else: | |
tool_responses.append(glm.Part(function_response=glm.FunctionResponse(name=func_name, response={'error': f"Function '{func_name}' not found."}))) | |
response = st.session_state.chat_session.send_message(glm.Content(parts=tool_responses)) | |
tool_calls = [part.function_call for part in response.candidates[0].content.parts if part.function_call] | |
st.session_state.chat_history.append({"role": "model", "parts": response.text}) | |
st.rerun() |