File size: 5,803 Bytes
47cec11 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 |
import pandas as pd
import requests
import yfinance as yf
from autogluon.timeseries import TimeSeriesPredictor, TimeSeriesDataFrame
import gradio as gr
# Function to fetch stock data
def get_stock_data(ticker, period):
    """Download price history for one ticker via yfinance.

    Args:
        ticker: Symbol understood by Yahoo Finance, e.g. "2330.TW" or "NVDA".
        period: Look-back window handed to ``yf.download`` (e.g. "3mo",
            "6mo", "1y"). Year spellings such as "1yr" or "2yrs" are
            rewritten to the "<n>y" form first; every other value is
            forwarded untouched.

    Returns:
        Whatever ``yf.download`` returns for that ticker and period.
    """
    if isinstance(period, str):
        cleaned = period.strip().lower()
        for suffix in ("yrs", "yr"):
            count = cleaned[: -len(suffix)]
            if cleaned.endswith(suffix) and count.isdigit():
                # The UI dropdown offers "1yr", but yfinance documents the
                # year period as "1y"; translate instead of letting the
                # download come back empty or fail.
                period = f"{count}y"
                break
    return yf.download(ticker, period=period)
# Function to prepare the data for Chronos-Bolt
def prepare_data_chronos(data):
    """Convert a downloaded price table into an AutoGluon TimeSeriesDataFrame.

    The closing prices become one series whose ``item_id`` is "stock".

    Args:
        data: Price table whose index resets to a ``Date`` column and which
            has a ``Close`` column.

    Returns:
        A ``TimeSeriesDataFrame`` built from the columns ``item_id``,
        ``timestamp`` and ``target``, with rows sorted by timestamp.

    Raises:
        Exception: Any error from ``TimeSeriesDataFrame.from_data_frame`` is
            printed and then re-raised unchanged.
    """
    flat = data.reset_index()
    timestamps = pd.to_datetime(flat['Date'])
    # ravel() yields a 1-D array whether 'Close' selects a Series or a
    # one-column frame.
    closes = flat['Close'].astype('float32').values.ravel()

    long_format = pd.DataFrame(
        {
            'item_id': ['stock'] * len(flat),
            'timestamp': timestamps,
            'target': closes,
        }
    ).sort_values('timestamp')

    try:
        # No target column is named here; only the id and timestamp columns.
        return TimeSeriesDataFrame.from_data_frame(
            long_format,
            id_column='item_id',
            timestamp_column='timestamp',
        )
    except Exception as e:
        print(f"Error creating TimeSeriesDataFrame: {str(e)}")
        raise
# Functions to fetch stock indices
def get_tw0050_stocks():
    """Fetch the Taiwan 50 stock codes as Yahoo Finance tickers.

    Returns:
        list[str]: Every key under "TW0050" in the JSON response, each
        suffixed with ".TW".

    Raises:
        requests.RequestException: On connection failure, timeout, or an
            HTTP error status.
        KeyError: If the response JSON has no "TW0050" entry.
    """
    # requests waits forever by default; bound the wait so a stalled server
    # cannot hang the caller.
    response = requests.get('https://answerbook.david888.com/TW0050', timeout=10)
    # Fail on 4xx/5xx here rather than while parsing an error page as JSON.
    response.raise_for_status()
    data = response.json()
    return [f"{code}.TW" for code in data['TW0050']]
def get_sp500_stocks(limit=50):
    """Fetch S&P 500 tickers, truncated to the first ``limit`` entries.

    Args:
        limit: Maximum number of tickers to return, taken in the order the
            keys appear in the response.

    Returns:
        list[str]: Up to ``limit`` keys from the "SP500" entry of the JSON
        response.

    Raises:
        requests.RequestException: On connection failure, timeout, or an
            HTTP error status.
        KeyError: If the response JSON has no "SP500" entry.
    """
    # requests waits forever by default; bound the wait so a stalled server
    # cannot hang the caller.
    response = requests.get('https://answerbook.david888.com/SP500', timeout=10)
    # Fail on 4xx/5xx here rather than while parsing an error page as JSON.
    response.raise_for_status()
    data = response.json()
    return list(data['SP500'])[:limit]
def get_nasdaq_stocks(limit=50):
    """Fetch NASDAQ-100 tickers, truncated to the first ``limit`` entries.

    Args:
        limit: Maximum number of tickers to return, taken in the order the
            keys appear in the response.

    Returns:
        list[str]: Up to ``limit`` keys from the "stocks" entry of the JSON
        response.

    Raises:
        requests.RequestException: On connection failure, timeout, or an
            HTTP error status.
        KeyError: If the response JSON has no "stocks" entry.
    """
    # NOTE(review): this endpoint is a bare IP over plaintext HTTP, so the
    # response is neither encrypted nor authenticated; confirm that is
    # acceptable or move it behind HTTPS.
    # requests waits forever by default; bound the wait so a stalled server
    # cannot hang the caller.
    response = requests.get('http://13.125.121.198:8090/stocks/NASDAQ100', timeout=10)
    # Fail on 4xx/5xx here rather than while parsing an error page as JSON.
    response.raise_for_status()
    data = response.json()
    return list(data['stocks'])[:limit]
def get_tw0051_stocks():
    """Fetch the Taiwan Mid-Cap 100 stock codes as Yahoo Finance tickers.

    Returns:
        list[str]: Every key under "TW0051" in the JSON response, each
        suffixed with ".TW".

    Raises:
        requests.RequestException: On connection failure, timeout, or an
            HTTP error status.
        KeyError: If the response JSON has no "TW0051" entry.
    """
    # requests waits forever by default; bound the wait so a stalled server
    # cannot hang the caller.
    response = requests.get('https://answerbook.david888.com/TW0051', timeout=10)
    # Fail on 4xx/5xx here rather than while parsing an error page as JSON.
    response.raise_for_status()
    data = response.json()
    return [f"{code}.TW" for code in data['TW0051']]
def get_sox_stocks():
    """Return the hard-coded semiconductor ticker list used for the SOX option.

    Unlike the other index helpers this makes no network call; the symbols
    are a fixed snapshot kept in the source.

    Returns:
        list[str]: A new list of 30 ticker symbols on every call.
    """
    symbols = (
        "NVDA AVGO GFS CRUS ON ASML QCOM SWKS MPWR ADI "
        "TSM AMD TXN QRVO AMKR MU ARM NXPI TER ENTG "
        "LSCC COHR ONTO MTSI KLAC LRCX MRVL AMAT INTC MCHP"
    )
    return symbols.split()
def get_dji_stocks():
    """Fetch the Dow Jones tickers.

    Returns:
        list[str]: Every key from the "stocks" entry of the JSON response.

    Raises:
        requests.RequestException: On connection failure, timeout, or an
            HTTP error status.
        KeyError: If the response JSON has no "stocks" entry.
    """
    # NOTE(review): this endpoint is a bare IP over plaintext HTTP, so the
    # response is neither encrypted nor authenticated; confirm that is
    # acceptable or move it behind HTTPS.
    # requests waits forever by default; bound the wait so a stalled server
    # cannot hang the caller.
    response = requests.get('http://13.125.121.198:8090/stocks/DOWJONES', timeout=10)
    # Fail on 4xx/5xx here rather than while parsing an error page as JSON.
    response.raise_for_status()
    data = response.json()
    return list(data['stocks'])
# Function to get top 10 potential stocks
def _collect_index_tickers(selected_indices):
    """Return the de-duplicated tickers for the selected index labels, in order."""
    selected_indices = selected_indices or []
    tickers = []
    if "\u53f0\u706350" in selected_indices:  # Taiwan 50
        tickers += get_tw0050_stocks()
    if "\u53f0\u7063\u4e2d\u578b100" in selected_indices:  # Taiwan Mid-Cap 100
        tickers += get_tw0051_stocks()
    if "S&P\u7cbe\u7c21\u724850" in selected_indices:  # S&P, first 50
        tickers += get_sp500_stocks()
    if "NASDAQ\u7cbe\u7c21\u724850" in selected_indices:  # NASDAQ, first 50
        tickers += get_nasdaq_stocks()
    if "\u8cfd\u57ce\u534a\u5b57\u9ad4SOX" in selected_indices:  # SOX semiconductors
        tickers += get_sox_stocks()
    if "\u9053\u74b0DJI" in selected_indices:  # Dow Jones
        tickers += get_dji_stocks()
    # The indices overlap (one symbol can sit in several lists). Forecasting a
    # ticker twice wastes a model run and would list it twice in the ranking.
    return list(dict.fromkeys(tickers))


def _forecast_potential(ticker, period, prediction_length):
    """Forecast one ticker; return (ticker, potential, last_close, predicted) or None."""
    import math

    data = get_stock_data(ticker, period)
    if data.empty:
        return None
    ts_data = prepare_data_chronos(data)

    # NOTE(review): prices exist only on trading days while the predictor is
    # configured with a calendar-day frequency; confirm this is intended.
    predictor = TimeSeriesPredictor(
        prediction_length=prediction_length,
        freq="1D"
    )
    predictor.fit(
        ts_data,
        hyperparameters={
            "Chronos": {"model_path": "autogluon/chronos-bolt-base"}
        }
    )
    predictions = predictor.predict(ts_data)

    # 'Close' may select a Series or a one-column frame depending on how the
    # download labels its columns; flatten so the last close is a plain float.
    last_close = float(data['Close'].to_numpy().ravel()[-1])
    # The forecast frame holds a "mean" column plus quantile columns; the
    # price forecast is the mean at the final step, not the whole last row.
    predicted_price = float(predictions["mean"].iloc[-1])

    potential = (predicted_price - last_close) / last_close
    if not math.isfinite(potential):
        # A NaN close or forecast cannot be ranked and would corrupt the sort.
        return None
    return ticker, potential, last_close, predicted_price


def get_top_10_potential_stocks(period, selected_indices):
    """Rank the stocks of the selected indices by forecast upside.

    For every ticker the price history is downloaded, a Chronos-Bolt
    forecast is produced for the next 10 steps, and the potential is
    computed as ``(predicted - last_close) / last_close``.

    Args:
        period: Look-back window forwarded to ``get_stock_data``.
        selected_indices: Collection of index labels as shown in the UI;
            unknown labels are ignored and ``None`` is treated as empty.

    Returns:
        list[tuple]: Up to 10 ``(ticker, potential, last_close,
        predicted_price)`` tuples, highest potential first. Tickers whose
        download is empty or whose forecast fails are left out.
    """
    prediction_length = 10
    stock_predictions = []
    for ticker in _collect_index_tickers(selected_indices):
        try:
            result = _forecast_potential(ticker, period, prediction_length)
        except Exception as e:
            # Best effort: one failing ticker must not abort the whole scan.
            print(f"Stock {ticker} error: {str(e)}")
            continue
        if result is not None:
            stock_predictions.append(result)

    # Sort stocks by potential in descending order, take top 10
    return sorted(stock_predictions, key=lambda x: x[1], reverse=True)[:10]
# Gradio interface function
def stock_prediction_app(period, selected_indices):
    """Gradio callback: rank the selected stocks and return them as a table.

    Args:
        period: Look-back window chosen in the UI.
        selected_indices: Index labels ticked in the UI.

    Returns:
        pandas.DataFrame: One row per result of
        ``get_top_10_potential_stocks`` with four columns: ticker, potential,
        current price and predicted price.
    """
    # NOTE(review): the potential header reads "percentage" but the value
    # supplied by get_top_10_potential_stocks is a ratio; confirm the units.
    column_names = [
        "\u80a1\u7968\u4ee3\u865f",  # Ticker
        "\u6f5b\u529b (\u767e\u5206\u6bd4)",  # Potential
        "\u73fe\u50f9",  # Current Price
        "\u9810\u6e2c\u50f9\u683c",  # Predicted Price
    ]
    ranked = get_top_10_potential_stocks(period, selected_indices)
    return pd.DataFrame(ranked, columns=column_names)
# Define Gradio interface
# Look-back period selector (label reads "time range").
_period_dropdown = gr.Dropdown(
    label="\u6642\u9593\u7bc4\u570d",
    choices=["3mo", "6mo", "9mo", "1yr"],
)

# Multi-select of stock indices (label reads "index selection"); the two
# Taiwan indices are ticked by default.
_index_checkboxes = gr.CheckboxGroup(
    label="\u6307\u6578\u9078\u64c7",
    value=["\u53f0\u706350", "\u53f0\u7063\u4e2d\u578b100"],
    choices=[
        "\u53f0\u706350",  # Taiwan 50
        "\u53f0\u7063\u4e2d\u578b100",  # Taiwan Mid-Cap 100
        "S&P\u7cbe\u7c21\u724850",  # S&P condensed 50
        "NASDAQ\u7cbe\u7c21\u724850",  # NASDAQ condensed 50
        "\u8cfd\u57ce\u534a\u5b57\u9ad4SOX",  # Philadelphia Semiconductor (SOX)
        "\u9053\u74b0DJI",  # Dow Jones (DJI)
    ],
)

inputs = [_period_dropdown, _index_checkboxes]

# Result table (label reads "potential stock recommendation results").
outputs = gr.Dataframe(label="\u6f5b\u529b\u80a1\u63a8\u85a6\u7d50\u679c")

# Wire the callback to the widgets; the title names the Taiwan/US stock
# recommendation system and the Chronos-Bolt model.
app = gr.Interface(
    title="\u53f0\u80a1\u7f8e\u80a1\u6f5b\u529b\u80a1\u63a8\u85a6\u7cfb\u7d71 - Chronos-Bolt\u6a21\u578b",
    fn=stock_prediction_app,
    inputs=inputs,
    outputs=outputs,
)
if __name__ == "__main__":
    # Start the Gradio server only when executed as a script, not on import.
    app.launch()