File size: 4,867 Bytes
1e2adc2
 
a45959e
1e2adc2
 
 
 
 
 
 
 
 
0091ee4
1e2adc2
ff03388
1e2adc2
 
ff03388
4fe64cc
1e2adc2
4fe64cc
ff03388
 
4fe64cc
 
ff03388
4fe64cc
 
0091ee4
 
4fe64cc
0091ee4
ff03388
 
 
 
1e2adc2
 
 
 
a53f74c
1e2adc2
 
 
 
a53f74c
1e2adc2
 
 
 
 
 
 
 
 
a53f74c
1e2adc2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import pandas as pd
import requests
import yfinance as yf
from autogluon.timeseries import TimeSeriesPredictor, TimeSeriesDataFrame
import gradio as gr

# yfinance only accepts a fixed set of ``period`` strings; labels outside that
# set (the UI offers "9mo" and "1yr") are translated before downloading.
_YF_PERIODS = {"1d", "5d", "1mo", "3mo", "6mo", "1y", "2y", "5y", "10y", "ytd", "max"}
_PERIOD_ALIASES = {"1yr": "1y"}


def _period_to_download_kwargs(period):
    """Translate a look-back label into keyword arguments for ``yf.download``.

    Labels yfinance understands are passed through as ``period``. "1yr" is
    mapped to yfinance's "1y". Month counts yfinance does not support (e.g.
    "9mo") become an explicit ``start`` date that many months before today.
    Anything else is passed through unchanged so yfinance reports the problem.
    """
    period = _PERIOD_ALIASES.get(period, period)
    if period in _YF_PERIODS:
        return {"period": period}
    if isinstance(period, str) and period.endswith("mo") and period[:-2].isdigit():
        months = int(period[:-2])
        start = pd.Timestamp.today().normalize() - pd.DateOffset(months=months)
        return {"start": start.strftime("%Y-%m-%d")}
    return {"period": period}


# Function to fetch stock data
def get_stock_data(ticker, period):
    """Download price history for ``ticker`` with yfinance.

    Args:
        ticker: Yahoo Finance symbol, e.g. "AAPL" or "2330.TW".
        period: look-back label such as "3mo", "9mo" or "1yr".

    Returns:
        The DataFrame returned by ``yf.download`` with single-level columns
        ("Close", "Open", ...). It may be empty; callers check ``data.empty``.
    """
    data = yf.download(ticker, **_period_to_download_kwargs(period))
    # Newer yfinance versions return (field, ticker) MultiIndex columns even for
    # a single ticker; flatten so data["Close"] is a plain Series.
    if isinstance(data.columns, pd.MultiIndex):
        data.columns = data.columns.get_level_values(0)
    return data

def _to_long_format(data):
    """Reshape a price frame into the long format AutoGluon expects.

    Returns a new DataFrame with columns ``timestamp`` (datetime), ``target``
    (float32 closing price) and ``item_id`` (always "stock"). The input frame
    is not modified.
    """
    frame = data.copy()
    # Recent yfinance releases return (field, ticker) MultiIndex columns even
    # for one ticker; keep only the field level so "Close" is a plain column.
    if isinstance(frame.columns, pd.MultiIndex):
        frame.columns = frame.columns.get_level_values(0)
    frame = frame.reset_index()
    # Daily yfinance data carries its dates in an index named "Date"; if the
    # index has another name (e.g. intraday "Datetime"), reset_index puts it
    # in the first column instead.
    time_col = "Date" if "Date" in frame.columns else frame.columns[0]
    # Building a fresh frame avoids assigning into a column-selection slice,
    # which pandas flags with SettingWithCopyWarning.
    return pd.DataFrame(
        {
            "timestamp": pd.to_datetime(frame[time_col]),
            "target": frame["Close"].astype("float32"),
            "item_id": "stock",
        }
    )


# Function to prepare the data for Chronos-Bolt
def prepare_data_chronos(data):
    """Convert a yfinance price frame into a single-series TimeSeriesDataFrame.

    Args:
        data: DataFrame with a "Close" column and dates either in the index or
            in a "Date" column (as returned by ``get_stock_data``).

    Returns:
        A ``TimeSeriesDataFrame`` keyed by ``item_id`` ("stock") and
        ``timestamp``, whose ``target`` column holds the closing prices.
    """
    return TimeSeriesDataFrame.from_data_frame(
        _to_long_format(data),
        id_column="item_id",
        timestamp_column="timestamp",
    )


# Functions that fetch the constituent tickers of each supported stock index
def get_tw0050_stocks(timeout=10):
    """Fetch the TW0050 constituent codes as yfinance tickers.

    Args:
        timeout: seconds to wait for the HTTP response before giving up.

    Returns:
        A list of codes from the ``TW0050`` key of the endpoint's JSON, each
        suffixed with ".TW".

    Raises:
        requests.RequestException: on network failure, timeout or HTTP error.
    """
    # Without a timeout, requests can block forever on an unresponsive server.
    response = requests.get('https://answerbook.david888.com/TW0050', timeout=timeout)
    response.raise_for_status()
    data = response.json()
    return [f"{code}.TW" for code in data['TW0050']]

def get_sp500_stocks(limit=50, timeout=10):
    """Fetch the first ``limit`` S&P 500 tickers from the remote endpoint.

    Args:
        limit: maximum number of tickers to return (in the endpoint's order).
        timeout: seconds to wait for the HTTP response before giving up.

    Returns:
        A list of keys from the ``SP500`` object of the endpoint's JSON.

    Raises:
        requests.RequestException: on network failure, timeout or HTTP error.
    """
    # Without a timeout, requests can block forever on an unresponsive server.
    response = requests.get('https://answerbook.david888.com/SP500', timeout=timeout)
    response.raise_for_status()
    data = response.json()
    return list(data['SP500'])[:limit]

def get_nasdaq_stocks(limit=50, timeout=10):
    """Fetch the first ``limit`` NASDAQ-100 tickers from the remote endpoint.

    Args:
        limit: maximum number of tickers to return (in the endpoint's order).
        timeout: seconds to wait for the HTTP response before giving up.

    Returns:
        A list of keys from the ``stocks`` object of the endpoint's JSON.

    Raises:
        requests.RequestException: on network failure, timeout or HTTP error.
    """
    # Without a timeout, requests can block forever on an unresponsive server.
    response = requests.get('http://13.125.121.198:8090/stocks/NASDAQ100', timeout=timeout)
    response.raise_for_status()
    data = response.json()
    return list(data['stocks'])[:limit]

def get_tw0051_stocks(timeout=10):
    """Fetch the TW0051 constituent codes as yfinance tickers.

    Args:
        timeout: seconds to wait for the HTTP response before giving up.

    Returns:
        A list of codes from the ``TW0051`` key of the endpoint's JSON, each
        suffixed with ".TW".

    Raises:
        requests.RequestException: on network failure, timeout or HTTP error.
    """
    # Without a timeout, requests can block forever on an unresponsive server.
    response = requests.get('https://answerbook.david888.com/TW0051', timeout=timeout)
    response.raise_for_status()
    data = response.json()
    return [f"{code}.TW" for code in data['TW0051']]

def get_sox_stocks():
    """Return the fixed list of semiconductor tickers used for the SOX option.

    Unlike the other index helpers this makes no network call. A new list is
    built on every call, so callers may extend it without side effects.
    """
    symbols = (
        "NVDA AVGO GFS CRUS ON ASML QCOM SWKS MPWR ADI "
        "TSM AMD TXN QRVO AMKR MU ARM NXPI TER ENTG "
        "LSCC COHR ONTO MTSI KLAC LRCX MRVL AMAT INTC MCHP"
    )
    return symbols.split()

def get_dji_stocks(timeout=10):
    """Fetch the Dow Jones tickers from the remote endpoint.

    Args:
        timeout: seconds to wait for the HTTP response before giving up.

    Returns:
        A list of keys from the ``stocks`` object of the endpoint's JSON.

    Raises:
        requests.RequestException: on network failure, timeout or HTTP error.
    """
    # Without a timeout, requests can block forever on an unresponsive server.
    response = requests.get('http://13.125.121.198:8090/stocks/DOWJONES', timeout=timeout)
    response.raise_for_status()
    data = response.json()
    return list(data['stocks'])

def _collect_tickers(selected_indices, sources):
    """Return the de-duplicated tickers of every selected index, in source order.

    ``sources`` is a sequence of ``(label, fetch)`` pairs; ``fetch`` is called
    only when ``label`` is in ``selected_indices``. A source that raises is
    reported and skipped so one unreachable endpoint does not abort the run.
    """
    selected = selected_indices or ()
    tickers = []
    for label, fetch in sources:
        if label not in selected:
            continue
        try:
            tickers.extend(fetch())
        except Exception as e:
            print(f"Index {label} error: {e}")
    # A ticker can belong to several indices; dict.fromkeys drops repeats while
    # keeping first-seen order, so each ticker is trained and listed once.
    return list(dict.fromkeys(tickers))


# Function to get top 10 potential stocks
def get_top_10_potential_stocks(period, selected_indices, prediction_length=10, top_n=10):
    """Forecast every ticker of the selected indices and rank them by upside.

    Args:
        period: look-back label passed to ``get_stock_data`` (e.g. "3mo").
        selected_indices: UI labels of the indices to scan.
        prediction_length: forecast horizon, in steps of the input series.
        top_n: maximum number of results to return.

    Returns:
        Up to ``top_n`` tuples ``(ticker, potential, current_price,
        predicted_price)`` sorted by ``potential`` descending. ``potential`` is
        the fractional change from the last close to the mean forecast at the
        final horizon step. Tickers that fail or give a non-finite potential
        are skipped.
    """
    import math
    import os
    import tempfile

    sources = [
        ("\u53f0\u706350", get_tw0050_stocks),
        ("\u53f0\u7063\u4e2d\u578b100", get_tw0051_stocks),
        ("S&P\u7cbe\u7c21\u724850", get_sp500_stocks),
        ("NASDAQ\u7cbe\u7c21\u724850", get_nasdaq_stocks),
        ("\u8cfd\u57ce\u534a\u5b57\u9ad4SOX", get_sox_stocks),
        ("\u9053\u74b0DJI", get_dji_stocks),
    ]
    stock_list = _collect_tickers(selected_indices, sources)

    hyperparameters = {"Chronos": {"model_path": "amazon/chronos-bolt-base"}}
    stock_predictions = []

    for ticker in stock_list:
        try:
            data = get_stock_data(ticker, period)
            if data.empty:
                continue

            ts_data = prepare_data_chronos(data)
            # Without an explicit path AutoGluon saves each predictor to a new
            # time-stamped folder under ./AutogluonModels, which would pile up
            # one directory per ticker per request; use a throwaway one.
            # NOTE(review): trading-day timestamps have gaps (weekends and
            # holidays); if the installed AutoGluon cannot infer a frequency
            # from them, pass freq= here — confirm.
            with tempfile.TemporaryDirectory() as tmp_dir:
                predictor = TimeSeriesPredictor(
                    prediction_length=prediction_length,
                    path=os.path.join(tmp_dir, "model"),
                )
                predictor.fit(ts_data, hyperparameters=hyperparameters)
                predictions = predictor.predict(ts_data)

            # predict() returns one row per future step with a "mean" column
            # plus quantile columns, so iloc[-1] alone is a whole row; take the
            # point forecast of the last step as a scalar.
            predicted_price = float(predictions["mean"].iloc[-1])
            current_price = float(ts_data["target"].iloc[-1])
            potential = (predicted_price - current_price) / current_price
        except Exception as e:
            print(f"Stock {ticker} error: {str(e)}")
            continue

        # NaN never compares as larger or smaller, so it would scramble the ranking.
        if math.isfinite(potential):
            stock_predictions.append((ticker, potential, current_price, predicted_price))

    return sorted(stock_predictions, key=lambda x: x[1], reverse=True)[:top_n]

# Gradio interface function
def stock_prediction_app(period, selected_indices):
    """Gradio callback: rank stocks and format them as a results table.

    Args:
        period: look-back label chosen in the UI dropdown.
        selected_indices: index labels ticked in the UI checkbox group.

    Returns:
        A DataFrame with one row per ranked stock: ticker, potential in
        percent (rounded to 2 decimals), current price and predicted price.
    """
    top_10_stocks = get_top_10_potential_stocks(period, selected_indices)
    # get_top_10_potential_stocks reports potential as a fraction (0.05), but
    # the column header promises a percentage, so scale it for display.
    rows = [
        (ticker, round(potential * 100, 2), current_price, predicted_price)
        for ticker, potential, current_price, predicted_price in top_10_stocks
    ]
    df = pd.DataFrame(rows, columns=["\u80a1\u7968\u4ee3\u865f", "\u6f5b\u529b (\u767e\u5206\u6bd4)", "\u73fe\u50f9", "\u9810\u6e2c\u50f9\u683c"])
    return df

# Define Gradio interface
inputs = [
    # A default keeps period=None from reaching yfinance when the user submits
    # without choosing a range.
    gr.Dropdown(choices=["3mo", "6mo", "9mo", "1yr"], value="6mo", label="\u6642\u9593\u7bc4\u570d"),
    # NOTE(review): the SOX and DJI labels decode to "賽城半字體SOX" and "道環DJI",
    # which look like typos of "費城半導體SOX" and "道瓊DJI". They are matched
    # verbatim in get_top_10_potential_stocks, so a rename must change both places.
    gr.CheckboxGroup(choices=["\u53f0\u706350", "\u53f0\u7063\u4e2d\u578b100", "S&P\u7cbe\u7c21\u724850", "NASDAQ\u7cbe\u7c21\u724850", "\u8cfd\u57ce\u534a\u5b57\u9ad4SOX", "\u9053\u74b0DJI"], label="\u6307\u6578\u9078\u64c7", value=["\u53f0\u706350", "\u53f0\u7063\u4e2d\u578b100"])
]
outputs = gr.Dataframe(label="\u6f5b\u529b\u80a1\u63a8\u85a6\u7d50\u679c")

demo = gr.Interface(fn=stock_prediction_app, inputs=inputs, outputs=outputs, title="\u53f0\u80a1\u7f8e\u80a1\u6f5b\u529b\u80a1\u63a8\u85a6\u7cfb\u7d71 - Chronos-Bolt\u6a21\u578b")

# Launch only when run as a script, so importing this module (e.g. from tests)
# does not start a web server.
if __name__ == "__main__":
    demo.launch()