File size: 4,953 Bytes
1e2adc2
 
a45959e
1e2adc2
 
 
 
 
 
 
 
 
0091ee4
1e2adc2
ff03388
1e2adc2
 
ff03388
 
1e2adc2
ff03388
 
 
 
 
 
 
 
1e2adc2
ff03388
 
 
0091ee4
 
 
 
 
ff03388
 
 
 
 
 
 
1e2adc2
 
 
 
a53f74c
1e2adc2
 
 
 
a53f74c
1e2adc2
 
 
 
 
 
 
 
 
a53f74c
1e2adc2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import pandas as pd
import requests
import yfinance as yf
from autogluon.timeseries import TimeSeriesPredictor, TimeSeriesDataFrame
import gradio as gr

# Function to fetch stock data
def get_stock_data(ticker, period):
    """Download price history for a single ticker from Yahoo Finance.

    Args:
        ticker: Yahoo Finance symbol, e.g. ``"2330.TW"`` or ``"NVDA"``.
        period: Look-back window passed straight to ``yf.download`` as
            ``period`` (e.g. ``"3mo"``).

    Returns:
        The DataFrame produced by ``yf.download``, with flat column names
        (``"Close"``, ...). It is empty when nothing was downloaded; callers
        check ``data.empty``.
    """
    data = yf.download(ticker, period=period)
    # Recent yfinance releases return (field, ticker) MultiIndex columns even
    # for one ticker. Keep only the field level so data["Close"] is a Series
    # and column renames downstream behave as expected.
    if isinstance(data.columns, pd.MultiIndex):
        data.columns = data.columns.get_level_values(0)
    return data

# Function to prepare the data for Chronos-Bolt

def prepare_data_chronos(data):
    """Convert a downloaded price frame into an AutoGluon ``TimeSeriesDataFrame``.

    Args:
        data: DataFrame whose index is named ``Date`` and which has a
            ``Close`` column. Flat columns and (field, ticker) MultiIndex
            columns are both accepted; with a MultiIndex only the field level
            is used, so the frame is assumed to hold a single ticker.

    Returns:
        A ``TimeSeriesDataFrame`` containing one series with item id
        ``"stock"``, a float32 ``target`` column holding the close price, and
        rows sorted by index.

    Raises:
        KeyError: if the ``Date`` index or the ``Close`` column is missing.
    """
    # Flatten on a copy so the caller's frame keeps its original columns.
    if isinstance(data.columns, pd.MultiIndex):
        data = data.copy()
        data.columns = data.columns.get_level_values(0)

    # Move the date index into a column and rename to the expected names.
    data = data.reset_index()
    data = data.rename(columns={"Date": "timestamp", "Close": "target"})

    # Keep only the columns that are needed.
    data = data[["timestamp", "target"]]

    # Enforce the expected dtypes.
    data = data.astype({
        "timestamp": "datetime64[ns]",
        "target": "float32"
    })

    # Every row belongs to the same (single) series.
    data["item_id"] = "stock"

    # The constructor only takes the id and timestamp column names. Other
    # keyword arguments are forwarded to pandas.DataFrame, which rejects
    # `target_column`. The target name is configured on TimeSeriesPredictor
    # instead, whose default is already "target".
    ts_data = TimeSeriesDataFrame(
        data,
        id_column="item_id",
        timestamp_column="timestamp"
    )

    # Make sure the series is ordered by (item_id, timestamp).
    ts_data = ts_data.sort_index()

    return ts_data


# Functions that return the ticker symbols of each supported stock index
def get_tw0050_stocks():
    """Return the TW0050 index constituents as Yahoo Finance tickers.

    The stock codes are the keys of the ``TW0050`` object in the JSON
    response; each one gets a ``.TW`` suffix.

    Returns:
        list[str]: tickers such as ``"<code>.TW"``.

    Raises:
        requests.RequestException: on network failure, timeout, or an HTTP
            error status.
        KeyError: if the response has no ``TW0050`` entry.
    """
    # Without a timeout a stalled server would block the request forever.
    response = requests.get('https://answerbook.david888.com/TW0050', timeout=10)
    # Fail with a clear HTTP error instead of trying to parse an error page.
    response.raise_for_status()
    data = response.json()
    return [f"{code}.TW" for code in data['TW0050']]

def get_sp500_stocks(limit=50):
    """Return the first ``limit`` tickers listed under ``SP500`` by the API.

    Args:
        limit: Maximum number of tickers to return, in the order the API
            lists them. ``None`` returns all of them.

    Returns:
        list[str]: ticker symbols (the keys of the ``SP500`` JSON object).

    Raises:
        requests.RequestException: on network failure, timeout, or an HTTP
            error status.
        KeyError: if the response has no ``SP500`` entry.
    """
    # Without a timeout a stalled server would block the request forever.
    response = requests.get('https://answerbook.david888.com/SP500', timeout=10)
    # Fail with a clear HTTP error instead of trying to parse an error page.
    response.raise_for_status()
    data = response.json()
    return list(data['SP500'])[:limit]

def get_nasdaq_stocks(limit=50):
    """Return the first ``limit`` tickers from the NASDAQ100 endpoint.

    Args:
        limit: Maximum number of tickers to return, in the order the API
            lists them. ``None`` returns all of them.

    Returns:
        list[str]: ticker symbols (the keys of the ``stocks`` JSON object).

    Raises:
        requests.RequestException: on network failure, timeout, or an HTTP
            error status.
        KeyError: if the response has no ``stocks`` entry.
    """
    # Without a timeout a stalled server would block the request forever.
    response = requests.get('http://13.125.121.198:8090/stocks/NASDAQ100', timeout=10)
    # Fail with a clear HTTP error instead of trying to parse an error page.
    response.raise_for_status()
    data = response.json()
    return list(data['stocks'])[:limit]

def get_tw0051_stocks():
    """Return the TW0051 index constituents as Yahoo Finance tickers.

    The stock codes are the keys of the ``TW0051`` object in the JSON
    response; each one gets a ``.TW`` suffix.

    Returns:
        list[str]: tickers such as ``"<code>.TW"``.

    Raises:
        requests.RequestException: on network failure, timeout, or an HTTP
            error status.
        KeyError: if the response has no ``TW0051`` entry.
    """
    # Without a timeout a stalled server would block the request forever.
    response = requests.get('https://answerbook.david888.com/TW0051', timeout=10)
    # Fail with a clear HTTP error instead of trying to parse an error page.
    response.raise_for_status()
    data = response.json()
    return [f"{code}.TW" for code in data['TW0051']]

def get_sox_stocks():
    """Return the hard-coded list of SOX tickers.

    No network access is involved. A fresh list is built on every call, so
    callers may mutate the result freely.
    """
    sox_members = (
        "NVDA", "AVGO", "GFS", "CRUS", "ON", "ASML",
        "QCOM", "SWKS", "MPWR", "ADI", "TSM", "AMD",
        "TXN", "QRVO", "AMKR", "MU", "ARM", "NXPI",
        "TER", "ENTG", "LSCC", "COHR", "ONTO", "MTSI",
        "KLAC", "LRCX", "MRVL", "AMAT", "INTC", "MCHP",
    )
    return list(sox_members)

def get_dji_stocks():
    """Return every ticker from the DOWJONES endpoint.

    Returns:
        list[str]: ticker symbols (the keys of the ``stocks`` JSON object).

    Raises:
        requests.RequestException: on network failure, timeout, or an HTTP
            error status.
        KeyError: if the response has no ``stocks`` entry.
    """
    # Without a timeout a stalled server would block the request forever.
    response = requests.get('http://13.125.121.198:8090/stocks/DOWJONES', timeout=10)
    # Fail with a clear HTTP error instead of trying to parse an error page.
    response.raise_for_status()
    data = response.json()
    return list(data['stocks'])

# Function to get top 10 potential stocks
def get_top_10_potential_stocks(period, selected_indices):
    """Rank the constituents of the selected indices by forecast upside.

    For every ticker, price history is downloaded, a Chronos-Bolt model is
    run through AutoGluon, and the mean forecast at the final step of a
    10-step horizon is compared with the latest close.

    Args:
        period: History window forwarded to ``get_stock_data``.
        selected_indices: Collection of index labels exactly as shown in the
            UI checkbox group. ``None`` is treated as an empty selection.

    Returns:
        list[tuple]: up to 10 tuples
        ``(ticker, potential, current_price, predicted_price)`` sorted by
        ``potential`` in descending order. ``potential`` is a fraction
        (``0.05`` means +5%), and all three numbers are plain floats.

    Tickers whose download or forecast raises are reported with ``print``
    and skipped, so one bad symbol does not abort the whole scan. Errors
    raised while fetching an index's ticker list are not caught here.
    """
    selected_indices = selected_indices or []

    # UI label -> function that lists that index's tickers.
    index_sources = (
        ("\u53f0\u706350", get_tw0050_stocks),
        ("\u53f0\u7063\u4e2d\u578b100", get_tw0051_stocks),
        ("S&P\u7cbe\u7c21\u724850", get_sp500_stocks),
        ("NASDAQ\u7cbe\u7c21\u724850", get_nasdaq_stocks),
        ("\u8cfd\u57ce\u534a\u5b57\u9ad4SOX", get_sox_stocks),
        ("\u9053\u74b0DJI", get_dji_stocks),
    )
    stock_list = []
    for label, fetch_tickers in index_sources:
        if label in selected_indices:
            stock_list += fetch_tickers()

    # Indices can share members; forecasting a ticker twice wastes time and
    # would put duplicate rows in the ranking. dict.fromkeys keeps first-seen
    # order.
    stock_list = list(dict.fromkeys(stock_list))

    stock_predictions = []
    prediction_length = 10

    for ticker in stock_list:
        try:
            data = get_stock_data(ticker, period)
            if data.empty:
                continue

            # data["Close"] is a one-column DataFrame when the download has
            # MultiIndex columns; reduce it to a Series either way.
            close = data["Close"]
            if isinstance(close, pd.DataFrame):
                close = close.iloc[:, 0]
            current_price = float(close.iloc[-1])

            ts_data = prepare_data_chronos(data)
            # NOTE(review): daily market data skips weekends and holidays.
            # Confirm the predictor can infer a frequency from it, or pass
            # `freq=` explicitly.
            predictor = TimeSeriesPredictor(prediction_length=prediction_length)
            predictor.fit(ts_data, hyperparameters={"Chronos": {"model_path": "amazon/chronos-bolt-base"}})

            predictions = predictor.predict(ts_data)
            # The forecast frame holds a "mean" column plus one column per
            # quantile. Taking the whole last row would make `potential` a
            # Series that cannot be sorted, so use the mean at the last step.
            predicted_price = float(predictions["mean"].iloc[-1])

            # NaN compares unpredictably inside sorted(); leave such rows out.
            if pd.isna(current_price) or pd.isna(predicted_price):
                continue

            potential = (predicted_price - current_price) / current_price
            stock_predictions.append((ticker, potential, current_price, predicted_price))

        except Exception as e:
            # Per-ticker best effort: report the failure and keep scanning.
            print(f"Stock {ticker} error: {str(e)}")
            continue

    top_10_stocks = sorted(stock_predictions, key=lambda x: x[1], reverse=True)[:10]
    return top_10_stocks

# Gradio interface function
def stock_prediction_app(period, selected_indices):
    """Gradio callback: tabulate the top-ranked stocks for the chosen inputs.

    Args:
        period: History window chosen in the dropdown.
        selected_indices: Index labels ticked in the checkbox group.

    Returns:
        pandas.DataFrame with one row per tuple returned by
        ``get_top_10_potential_stocks`` and four columns: ticker, potential,
        current price, predicted price (headers are in Chinese).
    """
    column_names = [
        "\u80a1\u7968\u4ee3\u865f",
        "\u6f5b\u529b (\u767e\u5206\u6bd4)",
        "\u73fe\u50f9",
        "\u9810\u6e2c\u50f9\u683c",
    ]
    ranked_rows = get_top_10_potential_stocks(period, selected_indices)
    return pd.DataFrame(ranked_rows, columns=column_names)

# Define Gradio interface
inputs = [
    # Choices are passed unchanged to yfinance as `period`, so they must be
    # period strings yfinance accepts. "1yr" is not one of them; "1y" is.
    # NOTE(review): "9mo" is not in yfinance's documented period list either;
    # confirm it is accepted by the installed version.
    gr.Dropdown(choices=["3mo", "6mo", "9mo", "1y"], label="\u6642\u9593\u7bc4\u570d"),
    # These labels must match, character for character, the strings checked
    # in get_top_10_potential_stocks.
    gr.CheckboxGroup(
        choices=[
            "\u53f0\u706350",
            "\u53f0\u7063\u4e2d\u578b100",
            "S&P\u7cbe\u7c21\u724850",
            "NASDAQ\u7cbe\u7c21\u724850",
            "\u8cfd\u57ce\u534a\u5b57\u9ad4SOX",
            "\u9053\u74b0DJI",
        ],
        label="\u6307\u6578\u9078\u64c7",
        value=["\u53f0\u706350", "\u53f0\u7063\u4e2d\u578b100"],
    ),
]
outputs = gr.Dataframe(label="\u6f5b\u529b\u80a1\u63a8\u85a6\u7d50\u679c")

# Build the UI and start the server as soon as the script runs.
gr.Interface(fn=stock_prediction_app, inputs=inputs, outputs=outputs, title="\u53f0\u80a1\u7f8e\u80a1\u6f5b\u529b\u80a1\u63a8\u85a6\u7cfb\u7d71 - Chronos-Bolt\u6a21\u578b").launch()