File size: 5,803 Bytes
47cec11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
import pandas as pd
import requests
import yfinance as yf
from autogluon.timeseries import TimeSeriesPredictor, TimeSeriesDataFrame
import gradio as gr

# Function to fetch stock data
def get_stock_data(ticker, period):
    """Download daily price history for ``ticker`` via yfinance.

    Args:
        ticker: Yahoo Finance symbol, e.g. ``"2330.TW"`` or ``"AAPL"``.
        period: Look-back window. Native yfinance period strings (``"1mo"``,
            ``"3mo"``, ``"6mo"``, ``"1y"``, ...), ``None`` and anything not
            recognised are passed to ``yf.download`` unchanged. Strings of the
            form ``"<n>mo"``, ``"<n>y"`` or ``"<n>yr"`` that yfinance does not
            accept natively (e.g. ``"9mo"``, ``"1yr"``, both offered by the UI)
            are converted to an explicit start date ``n`` months/years before
            today.

    Returns:
        The DataFrame returned by ``yf.download`` (callers treat an empty
        frame as "no data").
    """
    import re  # local import: only needed to parse non-native periods

    # Period strings yfinance/Yahoo accept directly for daily data.
    native_periods = {"1d", "5d", "1mo", "3mo", "6mo", "1y", "2y", "5y", "10y", "ytd", "max"}
    if period is None or period in native_periods:
        return yf.download(ticker, period=period)

    match = re.fullmatch(r"(\d+)(mo|yr|y)", str(period).strip())
    if match is None:
        # Unknown format: defer to yfinance exactly as before.
        return yf.download(ticker, period=period)

    count, unit = int(match.group(1)), match.group(2)
    offset = pd.DateOffset(months=count) if unit == "mo" else pd.DateOffset(years=count)
    start = (pd.Timestamp.today().normalize() - offset).strftime("%Y-%m-%d")
    return yf.download(ticker, start=start)

# Function to prepare the data for Chronos-Bolt
def prepare_data_chronos(data):
    """Convert a price frame into a single-series AutoGluon TimeSeriesDataFrame.

    Args:
        data: Frame as returned by ``get_stock_data``. After ``reset_index()``
            it must expose a ``'Date'`` column and a ``'Close'`` column
            (``'Close'`` may be multi-column, hence the ``ravel()``; presumably
            from yfinance MultiIndex output -- confirm).

    Returns:
        A ``TimeSeriesDataFrame`` with one item (``'stock'``), indexed by
        timestamp in ascending order, whose ``'target'`` column holds the
        closing prices as float32.

    Raises:
        KeyError: if ``'Date'`` or ``'Close'`` is missing.
        Exception: whatever ``TimeSeriesDataFrame.from_data_frame`` raises is
            printed and then re-raised unchanged.
    """
    frame = data.reset_index()

    # Evaluate the date column before the close column (same order as before),
    # so a frame missing both still reports the missing 'Date' first.
    stamps = pd.to_datetime(frame['Date'])
    closes = frame['Close'].astype('float32').values.ravel()

    long_form = pd.DataFrame(
        {
            'item_id': ['stock'] * len(frame),
            'timestamp': stamps,
            'target': closes,
        }
    ).sort_values('timestamp')

    try:
        # 'target' is AutoGluon's default target column name, so it is not passed.
        return TimeSeriesDataFrame.from_data_frame(
            long_form,
            id_column='item_id',
            timestamp_column='timestamp',
        )
    except Exception as exc:
        print(f"Error creating TimeSeriesDataFrame: {str(exc)}")
        raise

# Functions to fetch stock indices
def get_tw0050_stocks(timeout=10):
    """Return the Taiwan 50 (0050) constituents as Yahoo Finance tickers.

    Fetches ``https://answerbook.david888.com/TW0050`` and appends the ``.TW``
    suffix to every key of the JSON ``"TW0050"`` object.

    Args:
        timeout: Seconds to wait for the HTTP response; prevents a stalled
            server from hanging the UI request indefinitely.

    Raises:
        requests.HTTPError: if the endpoint answers with an error status.
        requests.RequestException: on connection failure or timeout.
    """
    response = requests.get('https://answerbook.david888.com/TW0050', timeout=timeout)
    # Surface HTTP failures directly instead of as a JSON decode / KeyError.
    response.raise_for_status()
    data = response.json()
    return [f"{code}.TW" for code in data['TW0050']]

def get_sp500_stocks(limit=50, timeout=10):
    """Return up to ``limit`` S&P 500 tickers from the answerbook endpoint.

    Tickers are the keys of the JSON ``"SP500"`` object, in the order the
    response lists them; ``limit=None`` returns all of them.

    Args:
        limit: Maximum number of tickers to return.
        timeout: Seconds to wait for the HTTP response.

    Raises:
        requests.HTTPError: if the endpoint answers with an error status.
        requests.RequestException: on connection failure or timeout.
    """
    response = requests.get('https://answerbook.david888.com/SP500', timeout=timeout)
    # Surface HTTP failures directly instead of as a JSON decode / KeyError.
    response.raise_for_status()
    data = response.json()
    return list(data['SP500'])[:limit]

def get_nasdaq_stocks(limit=50, timeout=10):
    """Return up to ``limit`` NASDAQ-100 tickers from the stocks service.

    Tickers are the keys of the JSON ``"stocks"`` object, in the order the
    response lists them; ``limit=None`` returns all of them.

    Args:
        limit: Maximum number of tickers to return.
        timeout: Seconds to wait for the HTTP response.

    Raises:
        requests.HTTPError: if the endpoint answers with an error status.
        requests.RequestException: on connection failure or timeout.
    """
    # NOTE(review): plain HTTP to a raw IP address -- the response is neither
    # encrypted nor authenticated; prefer an HTTPS endpoint if one exists.
    response = requests.get('http://13.125.121.198:8090/stocks/NASDAQ100', timeout=timeout)
    # Surface HTTP failures directly instead of as a JSON decode / KeyError.
    response.raise_for_status()
    data = response.json()
    return list(data['stocks'])[:limit]

def get_tw0051_stocks(timeout=10):
    """Return the Taiwan Mid-Cap 100 (0051) constituents as Yahoo Finance tickers.

    Fetches ``https://answerbook.david888.com/TW0051`` and appends the ``.TW``
    suffix to every key of the JSON ``"TW0051"`` object.

    Args:
        timeout: Seconds to wait for the HTTP response.

    Raises:
        requests.HTTPError: if the endpoint answers with an error status.
        requests.RequestException: on connection failure or timeout.
    """
    response = requests.get('https://answerbook.david888.com/TW0051', timeout=timeout)
    # Surface HTTP failures directly instead of as a JSON decode / KeyError.
    response.raise_for_status()
    data = response.json()
    return [f"{code}.TW" for code in data['TW0051']]

def get_sox_stocks():
    """Return the semiconductor (SOX) ticker list used by the app.

    Unlike the other index helpers this list is hard-coded rather than
    fetched, so it may need manual updates as the index changes. A new list
    is built on every call, so callers may extend it freely.
    """
    tickers = (
        "NVDA AVGO GFS CRUS ON ASML QCOM SWKS MPWR ADI "
        "TSM AMD TXN QRVO AMKR MU ARM NXPI TER ENTG "
        "LSCC COHR ONTO MTSI KLAC LRCX MRVL AMAT INTC MCHP"
    )
    return tickers.split()

def get_dji_stocks(timeout=10):
    """Return the Dow Jones Industrial Average tickers from the stocks service.

    Tickers are the keys of the JSON ``"stocks"`` object, in the order the
    response lists them.

    Args:
        timeout: Seconds to wait for the HTTP response.

    Raises:
        requests.HTTPError: if the endpoint answers with an error status.
        requests.RequestException: on connection failure or timeout.
    """
    # NOTE(review): plain HTTP to a raw IP address -- the response is neither
    # encrypted nor authenticated; prefer an HTTPS endpoint if one exists.
    response = requests.get('http://13.125.121.198:8090/stocks/DOWJONES', timeout=timeout)
    # Surface HTTP failures directly instead of as a JSON decode / KeyError.
    response.raise_for_status()
    data = response.json()
    return list(data['stocks'])

# Function to get top 10 potential stocks
def get_top_10_potential_stocks(period, selected_indices):
    """Rank tickers from the selected indices by forecast upside; return the best 10.

    For every unique ticker in the chosen indices this downloads ``period`` of
    daily closes, fits a Chronos-Bolt ``TimeSeriesPredictor`` and forecasts 10
    steps ahead. "Potential" is the relative change from the last observed
    close to the mean forecast of the final step.

    Args:
        period: Look-back window forwarded to ``get_stock_data``.
        selected_indices: UI labels of the indices to scan. Unknown labels are
            ignored; ``None`` or an empty selection yields ``[]``.

    Returns:
        Up to 10 tuples ``(ticker, potential, last_close, predicted_price)``,
        sorted by ``potential`` descending. ``potential`` is a fraction
        (0.05 means +5%); prices are plain floats.
    """
    prediction_length = 10
    stock_predictions = []

    for ticker in _collect_tickers(selected_indices):
        try:
            prediction = _forecast_potential(ticker, period, prediction_length)
        except Exception as e:
            # Best effort: one failing ticker (bad symbol, download or model
            # error) is reported and skipped instead of aborting the scan.
            print(f"Stock {ticker} error: {str(e)}")
            continue
        if prediction is not None:
            stock_predictions.append(prediction)

    # Sort stocks by potential in descending order, take top 10
    return sorted(stock_predictions, key=lambda x: x[1], reverse=True)[:10]


def _collect_tickers(selected_indices):
    """Return the unique tickers of every selected index, in first-seen order."""
    selected = set(selected_indices or ())
    if not selected:
        return []

    # UI label -> fetcher. Labels are compared verbatim with the CheckboxGroup
    # choices, so they must stay byte-identical to the UI strings.
    index_fetchers = [
        ("\u53f0\u706350", get_tw0050_stocks),
        ("\u53f0\u7063\u4e2d\u578b100", get_tw0051_stocks),
        ("S&P\u7cbe\u7c21\u724850", get_sp500_stocks),
        ("NASDAQ\u7cbe\u7c21\u724850", get_nasdaq_stocks),
        ("\u8cfd\u57ce\u534a\u5b57\u9ad4SOX", get_sox_stocks),
        ("\u9053\u74b0DJI", get_dji_stocks),
    ]

    tickers = []
    for label, fetch in index_fetchers:
        if label in selected:
            tickers += fetch()

    # Selected indices can share tickers (e.g. the hard-coded SOX chip names
    # may also appear in the NASDAQ/DJI lists); forecast each ticker once so
    # it cannot occupy several top-10 slots.
    return list(dict.fromkeys(tickers))


def _forecast_potential(ticker, period, prediction_length):
    """Forecast one ticker; return ``(ticker, potential, last_close, predicted)`` or None."""
    data = get_stock_data(ticker, period)
    if data.empty:
        return None

    ts_data = prepare_data_chronos(data)

    # Create a TimeSeriesPredictor for daily data.
    # NOTE(review): exchange data has no weekend rows; with freq="1D" the
    # horizon is in calendar days -- confirm this is intended. Each call also
    # fits a fresh predictor without `path=`; check where artifacts accumulate.
    predictor = TimeSeriesPredictor(
        prediction_length=prediction_length,
        freq="1D"
    )
    predictor.fit(
        ts_data,
        hyperparameters={
            "Chronos": {"model_path": "autogluon/chronos-bolt-base"}
        }
    )
    predictions = predictor.predict(ts_data)

    # predict() returns a "mean" column plus quantile columns; use the point
    # forecast of the last step so the ranking key is a scalar.
    predicted_price = float(predictions["mean"].iloc[-1])

    close = data["Close"]
    if isinstance(close, pd.DataFrame):
        # yfinance may return MultiIndex columns, making "Close" a one-column frame.
        close = close.iloc[:, 0]
    close = close.dropna()
    if close.empty:
        return None
    last_close = float(close.iloc[-1])
    if last_close <= 0:
        return None

    # Calculate potential as (prediction - last_close) / last_close
    potential = (predicted_price - last_close) / last_close
    if pd.isna(potential):
        return None
    return (ticker, potential, last_close, predicted_price)

# Gradio interface function
def stock_prediction_app(period, selected_indices):
    """Gradio callback: run the scan and format the top-10 table for display.

    ``get_top_10_potential_stocks`` reports potential as a fraction
    ((prediction - last_close) / last_close). The column header promises a
    percentage, so the value is scaled by 100 and rounded to 2 decimals here.

    Args:
        period: Look-back window chosen in the UI.
        selected_indices: Index labels ticked in the UI.

    Returns:
        A DataFrame with columns ticker, potential (%), current price and
        predicted price (headers in Traditional Chinese).
    """
    top_10_stocks = get_top_10_potential_stocks(period, selected_indices)
    rows = [
        (ticker, round(potential * 100, 2), current_price, predicted_price)
        for ticker, potential, current_price, predicted_price in top_10_stocks
    ]
    df = pd.DataFrame(rows, columns=[
        "\u80a1\u7968\u4ee3\u865f",      # Ticker
        "\u6f5b\u529b (\u767e\u5206\u6bd4)",  # Potential (percent)
        "\u73fe\u50f9",               # Current Price
        "\u9810\u6e2c\u50f9\u683c"       # Predicted Price
    ])
    return df

# Define Gradio interface
# Gradio UI: a look-back period selector plus a checklist of indices to scan.
# NOTE(review): "9mo" and "1yr" are not among yfinance's documented period
# strings; confirm get_stock_data accepts them.
# NOTE(review): the index labels are compared verbatim inside
# get_top_10_potential_stocks. Two of them look misspelled: '賽城半字體SOX'
# (likely meant 費城半導體, Philadelphia Semiconductor) and '道環DJI' (likely
# meant 道瓊, Dow Jones). Change both places together if corrected.
_INDEX_LABELS = [
    "\u53f0\u706350",                      # Taiwan 50
    "\u53f0\u7063\u4e2d\u578b100",          # Taiwan Mid-Cap 100
    "S&P\u7cbe\u7c21\u724850",             # S&P, lite version (50 tickers)
    "NASDAQ\u7cbe\u7c21\u724850",          # NASDAQ, lite version (50 tickers)
    "\u8cfd\u57ce\u534a\u5b57\u9ad4SOX",    # Philadelphia Semiconductor (SOX)
    "\u9053\u74b0DJI",                     # Dow Jones (DJI)
]

_period_selector = gr.Dropdown(
    choices=["3mo", "6mo", "9mo", "1yr"],
    label="\u6642\u9593\u7bc4\u570d",  # "Time range"
)
_index_selector = gr.CheckboxGroup(
    choices=list(_INDEX_LABELS),
    label="\u6307\u6578\u9078\u64c7",  # "Index selection"
    value=_INDEX_LABELS[:2],  # Taiwan 50 and Taiwan Mid-Cap 100 pre-checked
)
inputs = [_period_selector, _index_selector]

outputs = gr.Dataframe(label="\u6f5b\u529b\u80a1\u63a8\u85a6\u7d50\u679c")  # "Potential-stock recommendations"

app = gr.Interface(
    fn=stock_prediction_app,
    inputs=inputs,
    outputs=outputs,
    # "Taiwan/US potential-stock recommender - Chronos-Bolt model"
    title="\u53f0\u80a1\u7f8e\u80a1\u6f5b\u529b\u80a1\u63a8\u85a6\u7cfb\u7d71 - Chronos-Bolt\u6a21\u578b",
)

if __name__ == "__main__":
    app.launch()