File size: 19,402 Bytes
33062e0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
# app.py (Phiên bản cuối cùng với Biểu đồ Nâng cao)

import streamlit as st
import pandas as pd
import altair as alt # <-- Thêm thư viện Altair
import google.generativeai as genai
import google.ai.generativelanguage as glm
from dotenv import load_dotenv
import os
from twelvedata_api import TwelveDataAPI
from collections import deque
from datetime import datetime

# --- 1. CẤU HÌNH BAN ĐẦU & KHỞI TẠO STATE ---
load_dotenv()
st.set_page_config(layout="wide", page_title="AI Financial Dashboard")
genai.configure(api_key=os.getenv("GEMINI_API_KEY"))

def initialize_state():
    if "initialized" in st.session_state: return
    st.session_state.initialized = True
    st.session_state.td_api = TwelveDataAPI(os.getenv("TWELVEDATA_API_KEY"))
    st.session_state.stock_watchlist = {}
    st.session_state.timeseries_cache = {}
    st.session_state.active_timeseries_period = 'intraday' 
    st.session_state.currency_converter_state = {'from': 'USD', 'to': 'VND', 'amount': 100.0, 'result': None}
    st.session_state.chat_history = []
    st.session_state.active_tab = 'Danh sách mã chứng khoán'
    st.session_state.chat_session = None
initialize_state()

# --- 2. TẢI DỮ LIỆU NỀN ---
@st.cache_data(show_spinner="Đang tải và chuẩn bị dữ liệu thị trường...")
def load_market_data():
    td_api = st.session_state.td_api
    stocks_data = td_api.get_all_stocks()
    forex_data = td_api.get_forex_pairs()
    forex_graph = {}
    if forex_data and 'data' in forex_data:
        for pair in forex_data['data']:
            base, quote = pair['symbol'].split('/'); forex_graph.setdefault(base, []); forex_graph.setdefault(quote, []); forex_graph[base].append(quote); forex_graph[quote].append(base)
    country_currency_map = {}
    if stocks_data and 'data' in stocks_data:
        for stock in stocks_data['data']:
            country, currency = stock.get('country'), stock.get('currency')
            if country and currency: country_currency_map[country.lower()] = currency
    all_currencies = sorted(forex_graph.keys())
    return stocks_data, forex_graph, country_currency_map, all_currencies
ALL_STOCKS_CACHE, FOREX_GRAPH, COUNTRY_CURRENCY_MAP, AVAILABLE_CURRENCIES = load_market_data()

# --- 3. LOGIC THỰC THI TOOL ---
def find_and_process_stock(query: str):
    print(f"Hybrid searching for stock: '{query}'...")
    query_lower = query.lower()
    found_data = [s for s in ALL_STOCKS_CACHE.get('data', []) if query_lower in s['symbol'].lower() or query_lower in s['name'].lower()]
    if not found_data:
        results = st.session_state.td_api.get_stocks(symbol=query)
        found_data = results.get('data', [])
    if len(found_data) == 1:
        stock_info = found_data[0]; symbol = stock_info['symbol']
        st.session_state.stock_watchlist[symbol] = stock_info
        ts_data = get_smart_time_series(symbol=symbol, time_period='intraday')
        if 'values' in ts_data:
            df = pd.DataFrame(ts_data['values']); df['datetime'] = pd.to_datetime(df['datetime']); df['close'] = pd.to_numeric(df['close'])
            if symbol not in st.session_state.timeseries_cache: st.session_state.timeseries_cache[symbol] = {}
            st.session_state.timeseries_cache[symbol]['intraday'] = df.sort_values('datetime').set_index('datetime')
        st.session_state.active_tab = 'Biểu đồ thời gian'; st.session_state.active_timeseries_period = 'intraday'
        return {"status": "SINGLE_STOCK_PROCESSED", "symbol": symbol, "name": stock_info.get('name', 'N/A')}
    elif len(found_data) > 1: return {"status": "MULTIPLE_STOCKS_FOUND", "data": found_data[:5]}
    else: return {"status": "NO_STOCKS_FOUND"}
def get_smart_time_series(symbol: str, time_period: str):
    logic_map = {'intraday': {'interval': '15min', 'outputsize': 120}, '1_week': {'interval': '1h', 'outputsize': 40}, '1_month': {'interval': '1day', 'outputsize': 22}, '6_months': {'interval': '1day', 'outputsize': 120}, '1_year': {'interval': '1week', 'outputsize': 52}}
    params = logic_map.get(time_period)
    if not params: return {"error": f"Khoảng thời gian '{time_period}' không hợp lệ."}
    return st.session_state.td_api.get_time_series(symbol=symbol, **params)
def find_conversion_path_bfs(start, end):
    if start not in FOREX_GRAPH or end not in FOREX_GRAPH: return None
    q = deque([(start, [start])]); visited = {start}
    while q:
        curr, path = q.popleft()
        if curr == end: return path
        for neighbor in FOREX_GRAPH.get(curr, []):
            if neighbor not in visited: visited.add(neighbor); q.append((neighbor, path + [neighbor]))
    return None
def convert_currency_with_bridge(amount: float, symbol: str):
    try: start_currency, end_currency = symbol.upper().split('/')
    except ValueError: return {"error": "Định dạng cặp tiền tệ không hợp lệ."}
    path = find_conversion_path_bfs(start_currency, end_currency)
    if not path: return {"error": f"Không tìm thấy đường đi quy đổi từ {start_currency} sang {end_currency}."}
    current_amount = amount; steps = []
    for i in range(len(path) - 1):
        step_start, step_end = path[i], path[i+1]
        result = st.session_state.td_api.currency_conversion(amount=current_amount, symbol=f"{step_start}/{step_end}")
        if 'rate' in result and result.get('rate') is not None:
            current_amount = result['amount']; steps.append({"step": f"{i+1}. {step_start}{step_end}", "rate": result['rate'], "intermediate_amount": current_amount})
        else:
            inverse_result = st.session_state.td_api.currency_conversion(amount=1, symbol=f"{step_end}/{step_start}")
            if 'rate' in inverse_result and inverse_result.get('rate') and inverse_result['rate'] != 0:
                rate = 1 / inverse_result['rate']; current_amount *= rate; steps.append({"step": f"{i+1}. {step_start}{step_end} (Inverse)", "rate": rate, "intermediate_amount": current_amount})
            else: return {"error": f"Lỗi ở bước quy đổi từ {step_start} sang {step_end}."}
    return {"status": "Success", "original_amount": amount, "final_amount": current_amount, "path_taken": path, "conversion_steps": steps}
def perform_currency_conversion(amount: float, symbol: str):
    result = convert_currency_with_bridge(amount, symbol)
    st.session_state.currency_converter_state.update({'result': result, 'amount': amount})
    try:
        from_curr, to_curr = symbol.split('/'); st.session_state.currency_converter_state.update({'from': from_curr, 'to': to_curr})
    except: pass
    st.session_state.active_tab = 'Quy đổi tiền tệ'
    return result

# --- 4. CẤU HÌNH GEMINI ---
SYSTEM_INSTRUCTION = """Bạn là bộ não AI điều khiển một Bảng điều khiển Tài chính Tương tác. Nhiệm vụ của bạn là hiểu yêu cầu của người dùng, gọi các công cụ phù hợp, và thông báo kết quả một cách súc tích.

QUY TẮC VÀNG:
1.  **HIỂU TRƯỚC, GỌI SAU:**
    *   **Tên công ty:** Khi người dùng nhập một tên công ty (ví dụ: "Tập đoàn Vingroup", "Apple"), nhiệm vụ ĐẦU TIÊN của bạn là dùng tool `find_and_process_stock` để xác định mã chứng khoán chính thức.
    *   **Tên quốc gia:** Khi người dùng nhập tên quốc gia cho tiền tệ (ví dụ: "tiền Việt Nam"), bạn phải tự suy luận ra mã tiền tệ 3 chữ cái ("VND") TRƯỚC KHI gọi tool `perform_currency_conversion`.
2.  **HÀNH ĐỘNG VÀ THÔNG BÁO:** Vai trò của bạn là thực thi lệnh và thông báo ngắn gọn.
    *   **Tìm thấy 1 mã:** "Tôi đã tìm thấy [Tên công ty] ([Mã CK]) và đã tự động thêm vào danh sách theo dõi và biểu đồ của bạn."
    *   **Tìm thấy nhiều mã:** "Tôi tìm thấy một vài kết quả cho '[query]'. Bạn vui lòng cho biết mã chính xác bạn muốn theo dõi?"
    *   **Quy đổi tiền tệ:** "Đã thực hiện. Mời bạn xem kết quả chi tiết trong tab 'Quy đổi tiền tệ'."
3.  **CẤM LIỆT KÊ DỮ LIỆU:** Bảng điều khiển đã hiển thị tất cả. TUYỆT ĐỐI không lặp lại danh sách, các con số, hay dữ liệu thô trong câu trả lời của bạn.
"""
@st.cache_resource
def get_model_and_tools():
    find_stock_func = glm.FunctionDeclaration(name="find_and_process_stock", description="Tìm kiếm cổ phiếu theo mã hoặc tên và tự động xử lý. Dùng tool này ĐẦU TIÊN để xác định mã CK chính thức.", parameters=glm.Schema(type=glm.Type.OBJECT, properties={'query': glm.Schema(type=glm.Type.STRING, description="Mã hoặc tên công ty, ví dụ: 'Vingroup', 'Apple'.")}, required=['query']))
    get_ts_func = glm.FunctionDeclaration(name="get_smart_time_series", description="Lấy dữ liệu lịch sử giá sau khi đã biết mã CK chính thức.", parameters=glm.Schema(type=glm.Type.OBJECT, properties={'symbol': glm.Schema(type=glm.Type.STRING), 'time_period': glm.Schema(type=glm.Type.STRING, enum=["intraday", "1_week", "1_month", "6_months", "1_year"])}, required=['symbol', 'time_period']))
    currency_func = glm.FunctionDeclaration(name="perform_currency_conversion", description="Quy đổi tiền tệ sau khi đã biết mã 3 chữ cái của cặp tiền tệ nguồn/đích, ví dụ USD/VND, JPY/EUR", parameters=glm.Schema(type=glm.Type.OBJECT, properties={'amount': glm.Schema(type=glm.Type.NUMBER), 'symbol': glm.Schema(type=glm.Type.STRING)}, required=['amount', 'symbol']))
    finance_tool = glm.Tool(function_declarations=[find_stock_func, get_ts_func, currency_func])
    model = genai.GenerativeModel(model_name="gemini-1.5-pro-latest", tools=[finance_tool], system_instruction=SYSTEM_INSTRUCTION)
    return model
model = get_model_and_tools()
if st.session_state.chat_session is None:
    st.session_state.chat_session = model.start_chat(history=[])
AVAILABLE_FUNCTIONS = {"find_and_process_stock": find_and_process_stock, "get_smart_time_series": get_smart_time_series, "perform_currency_conversion": perform_currency_conversion}

# --- 5. LOGIC HIỂN THỊ CÁC TAB ---
def get_y_axis_domain(series: pd.Series, padding_percent: float = 0.1):
    if series.empty: return None
    data_min, data_max = series.min(), series.max()
    if pd.isna(data_min) or pd.isna(data_max): return None
    data_range = data_max - data_min
    if data_range == 0:
        padding = abs(data_max * (padding_percent / 2))
        return [data_min - padding, data_max + padding]
    padding = data_range * padding_percent
    return [data_min - padding, data_max + padding]

def render_watchlist_tab():
    st.subheader("Danh sách theo dõi")
    if not st.session_state.stock_watchlist: st.info("Chưa có cổ phiếu nào. Hãy thử tìm kiếm một mã như 'Apple' hoặc 'VNM'."); return
    for symbol, stock_info in list(st.session_state.stock_watchlist.items()):
        col1, col2, col3 = st.columns([4, 4, 1])
        with col1: st.markdown(f"**{symbol}**"); st.caption(stock_info.get('name', 'N/A'))
        with col2: st.markdown(f"**{stock_info.get('exchange', 'N/A')}**"); st.caption(f"{stock_info.get('country', 'N/A')} - {stock_info.get('currency', 'N/A')}")
        with col3:
            if st.button("🗑️", key=f"delete_{symbol}", help=f"Xóa {symbol}"):
                st.session_state.stock_watchlist.pop(symbol, None); st.session_state.timeseries_cache.pop(symbol, None); st.rerun()
        st.divider()

def render_timeseries_tab():
    st.subheader("Phân tích Biểu đồ")
    if not st.session_state.stock_watchlist:
        st.info("Hãy thêm ít nhất một cổ phiếu vào danh sách để xem biểu đồ."); return
    time_periods = {'Trong ngày': 'intraday', '1 Tuần': '1_week', '1 Tháng': '1_month', '6 Tháng': '6_months', '1 Năm': '1_year'}
    period_keys = list(time_periods.keys())
    period_values = list(time_periods.values())
    default_index = period_values.index(st.session_state.active_timeseries_period) if st.session_state.active_timeseries_period in period_values else 0
    selected_label = st.radio("Chọn khoảng thời gian:", options=period_keys, horizontal=True, index=default_index)
    selected_period = time_periods[selected_label]
    if st.session_state.active_timeseries_period != selected_period:
        st.session_state.active_timeseries_period = selected_period
        with st.spinner(f"Đang cập nhật biểu đồ..."):
            for symbol in st.session_state.stock_watchlist.keys():
                ts_data = get_smart_time_series(symbol, selected_period)
                if 'values' in ts_data:
                    df = pd.DataFrame(ts_data['values']); df['datetime'] = pd.to_datetime(df['datetime']); df['close'] = pd.to_numeric(df['close'])
                    if symbol not in st.session_state.timeseries_cache: st.session_state.timeseries_cache[symbol] = {}
                    st.session_state.timeseries_cache[symbol][selected_period] = df.sort_values('datetime').set_index('datetime')
        st.rerun()
    all_series_data = {symbol: st.session_state.timeseries_cache[symbol][selected_period] for symbol in st.session_state.stock_watchlist.keys() if symbol in st.session_state.timeseries_cache and selected_period in st.session_state.timeseries_cache[symbol]}
    if not all_series_data:
        st.warning("Không có đủ dữ liệu cho khoảng thời gian đã chọn."); return
    st.markdown("##### So sánh Hiệu suất Tăng trưởng (%)")
    normalized_dfs = []
    for symbol, df in all_series_data.items():
        if not df.empty:
            normalized_series = (df['close'] / df['close'].iloc[0]) * 100
            normalized_df = normalized_series.reset_index(); normalized_df.columns = ['datetime', 'value']; normalized_df['symbol'] = symbol
            normalized_dfs.append(normalized_df)
    if normalized_dfs:
        full_normalized_df = pd.concat(normalized_dfs)
        y_domain = get_y_axis_domain(full_normalized_df['value'])
        chart = alt.Chart(full_normalized_df).mark_line().encode(x=alt.X('datetime:T', title='Thời gian'), y=alt.Y('value:Q', scale=alt.Scale(domain=y_domain, zero=False), title='Tăng trưởng (%)'), color=alt.Color('symbol:N', title='Mã CK'), tooltip=[alt.Tooltip('symbol:N', title='Mã'), alt.Tooltip('datetime:T', title='Thời điểm', format='%Y-%m-%d %H:%M'), alt.Tooltip('value:Q', title='Tăng trưởng', format='.2f')]).interactive()
        st.altair_chart(chart, use_container_width=True)
    else:
        st.warning("Không có dữ liệu để vẽ biểu đồ tăng trưởng.")
    st.divider()
    st.markdown("##### Biểu đồ Giá Thực tế")
    for symbol, df in all_series_data.items():
        stock_info = st.session_state.stock_watchlist.get(symbol, {})
        st.markdown(f"**{symbol}** ({stock_info.get('currency', 'N/A')})")
        if not df.empty:
            y_domain = get_y_axis_domain(df['close'])
            data_for_chart = df.reset_index()
            price_chart = alt.Chart(data_for_chart).mark_line().encode(x=alt.X('datetime:T', title='Thời gian'), y=alt.Y('close:Q', scale=alt.Scale(domain=y_domain, zero=False), title='Giá'), tooltip=[alt.Tooltip('datetime:T', title='Thời điểm', format='%Y-%m-%d %H:%M'), alt.Tooltip('close:Q', title='Giá', format=',.2f')]).interactive()
            st.altair_chart(price_chart, use_container_width=True)

def render_currency_tab():
    st.subheader("Công cụ quy đổi tiền tệ"); state = st.session_state.currency_converter_state
    col1, col2 = st.columns(2)
    amount = col1.number_input("Số tiền", value=state['amount'], min_value=0.0, format="%.2f", key="conv_amount")
    from_curr = col1.selectbox("Từ", options=AVAILABLE_CURRENCIES, index=AVAILABLE_CURRENCIES.index(state['from']) if state['from'] in AVAILABLE_CURRENCIES else 0, key="conv_from")
    to_curr = col2.selectbox("Sang", options=AVAILABLE_CURRENCIES, index=AVAILABLE_CURRENCIES.index(state['to']) if state['to'] in AVAILABLE_CURRENCIES else 1, key="conv_to")
    if st.button("Quy đổi", use_container_width=True, key="conv_btn"):
        with st.spinner("Đang quy đổi..."): result = perform_currency_conversion(amount, f"{from_curr}/{to_curr}"); st.rerun()
    if state['result']:
        res = state['result']
        if res.get('status') == 'Success': st.success(f"**Kết quả:** `{res['original_amount']:,.2f} {res['path_taken'][0]}` = `{res['final_amount']:,.2f} {res['path_taken'][-1]}`")
        else: st.error(f"Lỗi: {res.get('error', 'Không rõ')}")

# --- 6. MAIN APP LAYOUT & CONTROL FLOW ---
st.title("📈 AI Financial Dashboard")

col1, col2 = st.columns([1, 1])

with col2:
    right_column_container = st.container(height=600)
    with right_column_container:
        tab_names = ['Danh sách mã chứng khoán', 'Biểu đồ thời gian', 'Quy đổi tiền tệ']
        try: default_index = tab_names.index(st.session_state.active_tab)
        except ValueError: default_index = 0
        st.session_state.active_tab = tab_names[default_index]
        
        tab1, tab2, tab3 = st.tabs(tab_names)
        with tab1: render_watchlist_tab()
        with tab2: render_timeseries_tab()
        with tab3: render_currency_tab()

with col1:
    chat_container = st.container(height=600)
    with chat_container:
        for message in st.session_state.chat_history:
            with st.chat_message(message["role"]):
                st.markdown(message["parts"])

user_prompt = st.chat_input("Hỏi AI để điều khiển bảng điều khiển...")
if user_prompt:
    st.session_state.chat_history.append({"role": "user", "parts": user_prompt})
    st.rerun()

if st.session_state.chat_history and st.session_state.chat_history[-1]["role"] == "user":
    last_user_prompt = st.session_state.chat_history[-1]["parts"]
    
    # ***** ĐÂY LÀ PHẦN THAY ĐỔI *****
    with chat_container: 
        with st.chat_message("model"):
            with st.spinner("🤖 AI đang thực thi lệnh..."):
                response = st.session_state.chat_session.send_message(last_user_prompt)
                tool_calls = [part.function_call for part in response.candidates[0].content.parts if part.function_call]
                while tool_calls:
                    tool_responses = []
                    for call in tool_calls:
                        func_name = call.name; func_args = {k: v for k, v in call.args.items()}
                        if func_name in AVAILABLE_FUNCTIONS:
                            tool_result = AVAILABLE_FUNCTIONS[func_name](**func_args)
                            tool_responses.append(glm.Part(function_response=glm.FunctionResponse(name=func_name, response={'result': tool_result})))
                        else:
                            tool_responses.append(glm.Part(function_response=glm.FunctionResponse(name=func_name, response={'error': f"Function '{func_name}' not found."})))
                    response = st.session_state.chat_session.send_message(glm.Content(parts=tool_responses))
                    tool_calls = [part.function_call for part in response.candidates[0].content.parts if part.function_call]
                
                st.session_state.chat_history.append({"role": "model", "parts": response.text})
                st.rerun()