File size: 5,808 Bytes
1e2adc2
 
a45959e
1e2adc2
 
 
 
 
 
 
 
 
0091ee4
1e2adc2
ead29e0
fc08f05
ff03388
ead29e0
 
 
 
9c0b92a
ead29e0
ff03388
ead29e0
 
ff03388
ead29e0
9c0b92a
843b7e5
6d62996
843b7e5
1d22072
6d62996
ead29e0
 
 
 
ff03388
 
426f91f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1e2adc2
 
 
 
a53f74c
1e2adc2
 
 
 
a53f74c
1e2adc2
 
 
 
 
 
 
 
 
a53f74c
1e2adc2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fc08f05
1e2adc2
 
 
 
 
fc08f05
1e2adc2
fc08f05
9c0b92a
ead29e0
 
9c0b92a
 
ead29e0
 
fc08f05
 
 
 
 
 
 
1e2adc2
 
 
fc08f05
1e2adc2
 
 
fc08f05
1e2adc2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
import pandas as pd
import requests
import yfinance as yf
from autogluon.timeseries import TimeSeriesPredictor, TimeSeriesDataFrame
import gradio as gr

# Function to fetch stock data
def get_stock_data(ticker, period):
    """Download price history for *ticker* with ``yf.download``.

    Args:
        ticker: Yahoo ticker symbol, e.g. ``"NVDA"`` or a Taiwan ``"<code>.TW"``.
        period: Look-back window. Yahoo range strings (``"3mo"``, ``"6mo"``,
            ``"1y"``, ...) are passed through unchanged. ``"1yr"`` is accepted
            as an alias of ``"1y"``. Month counts Yahoo does not offer as a
            range (e.g. ``"9mo"``) are converted into an explicit start date.

    Returns:
        The DataFrame from ``yf.download`` with single-level columns
        (``"Close"``, ...). It may be empty when no data is available.
    """
    # Ranges accepted by Yahoo's chart API; anything else is rejected there.
    valid_ranges = {"1d", "5d", "1mo", "3mo", "6mo", "1y", "2y", "5y", "10y", "ytd", "max"}
    period = {"1yr": "1y"}.get(period, period)

    if (isinstance(period, str) and period not in valid_ranges
            and period.endswith("mo") and period[:-2].isdigit()):
        # e.g. "9mo": express the window as a start date instead of a range.
        start = pd.Timestamp.today().normalize() - pd.DateOffset(months=int(period[:-2]))
        data = yf.download(ticker, start=start.strftime("%Y-%m-%d"))
    else:
        data = yf.download(ticker, period=period)

    # Newer yfinance versions return (price, ticker) MultiIndex columns even for
    # one ticker. Flatten them so data["Close"] is a plain Series downstream.
    if isinstance(data.columns, pd.MultiIndex):
        data.columns = data.columns.get_level_values(0)
    return data

# Function to prepare the data for Chronos-Bolt

def prepare_data_chronos(data):
    """Convert a yfinance price frame into a single-series TimeSeriesDataFrame.

    The closing price becomes the ``target`` column. Every row gets the
    constant item id ``'stock'``, and rows are ordered by timestamp.

    Args:
        data: Frame from ``yf.download``. It is expected to reset to a ``Date``
            column and to expose ``Close``. If ``Close`` comes back as a
            one-column sub-frame, ``ravel()`` flattens it to 1-D either way.

    Returns:
        The TimeSeriesDataFrame built by ``TimeSeriesDataFrame.from_data_frame``
        with ``item_id`` / ``timestamp`` as id and time columns.

    Raises:
        Re-raises whatever ``from_data_frame`` raises, after printing it.
    """
    flat = data.reset_index()
    row_count = len(flat)

    timestamps = pd.to_datetime(flat['Date'])
    closes = flat['Close'].astype('float32').values.ravel()

    # Long format expected by AutoGluon: one row per (item_id, timestamp).
    long_format = pd.DataFrame({
        'item_id': ['stock'] * row_count,
        'timestamp': timestamps,
        'target': closes,
    }).sort_values('timestamp')

    try:
        return TimeSeriesDataFrame.from_data_frame(
            long_format,
            id_column='item_id',
            timestamp_column='timestamp',
        )
    except Exception as err:
        print(f"Error creating TimeSeriesDataFrame: {str(err)}")
        raise


# def prepare_data_chronos(data):
#     # 直接使用收盤價序列
#     series = pd.Series(
#         data['Close'].values,
#         index=data.index,
#         name='value'
#     )
    
#     # 創建基本的時間序列數據框
#     df = pd.DataFrame({
#         'timestamp': series.index,
#         'value': series.values,
#         'item_id': ['stock'] * len(series)
#     })
    
#     return TimeSeriesDataFrame(df)


# Functions that fetch the constituent tickers of each supported index
def get_tw0050_stocks():
    """Fetch the Taiwan 50 constituents as Yahoo ``<code>.TW`` tickers.

    Raises:
        requests.HTTPError: if the endpoint answers with a non-2xx status.
        requests.Timeout: if connecting or reading takes longer than 10 s.
    """
    # Without a timeout a stalled server would block the Gradio request forever.
    response = requests.get('https://answerbook.david888.com/TW0050', timeout=10)
    response.raise_for_status()
    data = response.json()
    # The payload's 'TW0050' object is keyed by stock code.
    return [f"{code}.TW" for code in data['TW0050'].keys()]

def get_sp500_stocks(limit=50):
    """Fetch S&P 500 tickers and return at most the first *limit* of them.

    Args:
        limit: Maximum number of tickers to return, in payload order.

    Raises:
        requests.HTTPError: if the endpoint answers with a non-2xx status.
        requests.Timeout: if connecting or reading takes longer than 10 s.
    """
    response = requests.get('https://answerbook.david888.com/SP500', timeout=10)
    response.raise_for_status()
    data = response.json()
    return list(data['SP500'].keys())[:limit]

def get_nasdaq_stocks(limit=50):
    """Fetch NASDAQ-100 tickers and return at most the first *limit* of them.

    Args:
        limit: Maximum number of tickers to return, in payload order.

    Raises:
        requests.HTTPError: if the endpoint answers with a non-2xx status.
        requests.Timeout: if connecting or reading takes longer than 10 s.
    """
    # NOTE(review): plain HTTP to a raw IP address. The response is
    # unauthenticated and unencrypted, so consider an HTTPS endpoint.
    response = requests.get('http://13.125.121.198:8090/stocks/NASDAQ100', timeout=10)
    response.raise_for_status()
    data = response.json()
    return list(data['stocks'].keys())[:limit]

def get_tw0051_stocks():
    """Fetch the Taiwan Mid-Cap 100 constituents as Yahoo ``<code>.TW`` tickers.

    Raises:
        requests.HTTPError: if the endpoint answers with a non-2xx status.
        requests.Timeout: if connecting or reading takes longer than 10 s.
    """
    response = requests.get('https://answerbook.david888.com/TW0051', timeout=10)
    response.raise_for_status()
    data = response.json()
    # The payload's 'TW0051' object is keyed by stock code.
    return [f"{code}.TW" for code in data['TW0051'].keys()]

def get_sox_stocks():
    """Return the hard-coded 30-ticker list used for the SOX (semiconductor) selection.

    A fresh list is built on every call, so callers may mutate the result.
    """
    symbols = (
        "NVDA AVGO GFS CRUS ON ASML QCOM SWKS MPWR ADI "
        "TSM AMD TXN QRVO AMKR MU ARM NXPI TER ENTG "
        "LSCC COHR ONTO MTSI KLAC LRCX MRVL AMAT INTC MCHP"
    )
    return symbols.split()

def get_dji_stocks():
    """Fetch the Dow Jones Industrial Average tickers.

    Raises:
        requests.HTTPError: if the endpoint answers with a non-2xx status.
        requests.Timeout: if connecting or reading takes longer than 10 s.
    """
    # NOTE(review): plain HTTP to a raw IP address. The response is
    # unauthenticated and unencrypted, so consider an HTTPS endpoint.
    response = requests.get('http://13.125.121.198:8090/stocks/DOWJONES', timeout=10)
    response.raise_for_status()
    data = response.json()
    return list(data['stocks'].keys())

# Function to get top 10 potential stocks
def get_top_10_potential_stocks(period, selected_indices, prediction_length=10):
    """Forecast every stock of the selected indices and return the 10 best.

    Args:
        period: History window passed to ``get_stock_data``.
        selected_indices: Index labels chosen in the UI (the checkbox values).
        prediction_length: Forecast horizon in steps; the potential compares
            the last forecast step with the latest close.

    Returns:
        Up to 10 tuples ``(ticker, potential, last_close, predicted_close)``,
        sorted by ``potential`` descending. ``potential`` is the fraction
        ``(predicted - last_close) / last_close``.
    """
    stock_predictions = []
    for ticker in _collect_tickers(selected_indices):
        try:
            result = _forecast_potential(ticker, period, prediction_length)
        except Exception as e:
            # Best effort: one bad ticker must not abort the whole screen.
            print(f"Stock {ticker} error: {str(e)}")
            continue
        if result is None:
            continue
        potential, last_close, predicted = result
        # NaN compares false with everything and would scramble the ranking.
        if pd.isna(potential):
            continue
        stock_predictions.append((ticker, potential, last_close, predicted))

    return sorted(stock_predictions, key=lambda x: x[1], reverse=True)[:10]


def _collect_tickers(selected_indices):
    """Return de-duplicated tickers for every index label in *selected_indices*."""
    selected = selected_indices or []
    # Insertion order matches the order in which index lists are concatenated.
    fetchers = {
        "\u53f0\u706350": get_tw0050_stocks,
        "\u53f0\u7063\u4e2d\u578b100": get_tw0051_stocks,
        "S&P\u7cbe\u7c21\u724850": get_sp500_stocks,
        "NASDAQ\u7cbe\u7c21\u724850": get_nasdaq_stocks,
        "\u8cfd\u57ce\u534a\u5b57\u9ad4SOX": get_sox_stocks,
        "\u9053\u74b0DJI": get_dji_stocks,
    }
    tickers = []
    for label, fetch in fetchers.items():
        if label not in selected:
            continue
        try:
            tickers.extend(fetch())
        except Exception as e:
            # A dead index endpoint should not prevent screening the others.
            print(f"Index {label} error: {str(e)}")
    # Index lists can overlap (e.g. SOX names may also be in NASDAQ/DJI lists).
    # Keep the first occurrence so each stock is forecast and listed once.
    return list(dict.fromkeys(tickers))


def _forecast_potential(ticker, period, prediction_length):
    """Forecast *ticker*; return (potential, last_close, predicted) or None if no data."""
    import tempfile

    data = get_stock_data(ticker, period)
    if data.empty:
        return None

    ts_data = prepare_data_chronos(data)

    # Save predictor artefacts in a throw-away directory. Without `path`,
    # AutoGluon creates a new timestamped folder under the working directory
    # for every fit, which accumulates across hundreds of tickers.
    with tempfile.TemporaryDirectory() as model_dir:
        predictor = TimeSeriesPredictor(
            prediction_length=prediction_length,
            # NOTE(review): trading data has no weekend rows. With "D" AutoGluon
            # presumably regularises and fills those gaps; "B" may fit better.
            freq="D",
            target="target",  # must match the column built by prepare_data_chronos
            path=model_dir,
        )
        predictor.fit(
            ts_data,
            hyperparameters={
                "Chronos": {"model_path": "autogluon/chronos-bolt-base"}
            },
        )
        predictions = predictor.predict(ts_data)

    # 'Close' may be a one-column frame (MultiIndex columns); ravel handles both.
    last_close = float(data["Close"].to_numpy().ravel()[-1])
    # The prediction frame holds a 'mean' point forecast plus quantile columns.
    # Use the mean at the final horizon step.
    predicted = float(predictions["mean"].iloc[-1])
    potential = (predicted - last_close) / last_close
    return potential, last_close, predicted

# Gradio interface function
def stock_prediction_app(period, selected_indices):
    """Gradio callback: build the top-10 result table for the chosen inputs.

    Columns are ticker, potential (percent), current price and predicted price.
    """
    columns = ["\u80a1\u7968\u4ee3\u865f", "\u6f5b\u529b (\u767e\u5206\u6bd4)", "\u73fe\u50f9", "\u9810\u6e2c\u50f9\u683c"]
    top_10_stocks = get_top_10_potential_stocks(period, selected_indices)
    df = pd.DataFrame(top_10_stocks, columns=columns)
    # The potential is produced as a fraction ((predicted - close) / close),
    # but the header promises a percentage, so scale it for display.
    potential_col = columns[1]
    df[potential_col] = df[potential_col] * 100
    return df.round({potential_col: 2})

# Define Gradio interface
# Gradio UI: a history-period dropdown and an index checkbox group feed
# stock_prediction_app, whose DataFrame is shown as the output table.
inputs = [
    gr.Dropdown(
        # "1y" (not "1yr") is the range string yfinance/Yahoo accept.
        choices=["3mo", "6mo", "9mo", "1y"],
        # A default avoids calling the backend with period=None.
        value="6mo",
        label="\u6642\u9593\u7bc4\u570d",
    ),
    gr.CheckboxGroup(
        choices=["\u53f0\u706350", "\u53f0\u7063\u4e2d\u578b100", "S&P\u7cbe\u7c21\u724850", "NASDAQ\u7cbe\u7c21\u724850", "\u8cfd\u57ce\u534a\u5b57\u9ad4SOX", "\u9053\u74b0DJI"],
        label="\u6307\u6578\u9078\u64c7",
        value=["\u53f0\u706350", "\u53f0\u7063\u4e2d\u578b100"],
    ),
]
outputs = gr.Dataframe(label="\u6f5b\u529b\u80a1\u63a8\u85a6\u7d50\u679c")

demo = gr.Interface(
    fn=stock_prediction_app,
    inputs=inputs,
    outputs=outputs,
    title="\u53f0\u80a1\u7f8e\u80a1\u6f5b\u529b\u80a1\u63a8\u85a6\u7cfb\u7d71 - Chronos-Bolt\u6a21\u578b",
)

# Launch only when run as a script so importing this module has no side effects.
if __name__ == "__main__":
    demo.launch()