File size: 4,953 Bytes
1e2adc2 a45959e 1e2adc2 0091ee4 1e2adc2 ff03388 1e2adc2 ff03388 1e2adc2 ff03388 1e2adc2 ff03388 0091ee4 ff03388 1e2adc2 a53f74c 1e2adc2 a53f74c 1e2adc2 a53f74c 1e2adc2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 |
import pandas as pd
import requests
import yfinance as yf
from autogluon.timeseries import TimeSeriesPredictor, TimeSeriesDataFrame
import gradio as gr
# Function to fetch stock data
def get_stock_data(ticker, period):
    """Download daily price history for one ticker from Yahoo Finance.

    Args:
        ticker: Yahoo symbol, e.g. "AAPL" or a Taiwan code such as "2330.TW".
        period: Look-back window. Yahoo only accepts the range strings in
            ``valid_periods`` below. "1yr" is mapped to "1y". Any other
            "<N>mo" value outside that set (e.g. "9mo") is converted into an
            explicit start date N months before today.

    Returns:
        The DataFrame from ``yf.download`` with single-level columns
        (including "Close"). It may be empty when Yahoo has no data;
        callers check ``data.empty``.
    """
    valid_periods = {"1d", "5d", "1mo", "3mo", "6mo", "1y", "2y", "5y", "10y", "ytd", "max"}
    period = {"1yr": "1y"}.get(period, period)
    if (
        isinstance(period, str)
        and period not in valid_periods
        and period.endswith("mo")
        and period[:-2].isdigit()
    ):
        # Yahoo rejects arbitrary month ranges, so request an explicit start date instead.
        start = pd.Timestamp.today().normalize() - pd.DateOffset(months=int(period[:-2]))
        data = yf.download(ticker, start=start.strftime("%Y-%m-%d"))
    else:
        data = yf.download(ticker, period=period)
    # Newer yfinance releases return (field, ticker) MultiIndex columns even for a
    # single ticker; keep only the field level so data["Close"] is a plain Series.
    if isinstance(data.columns, pd.MultiIndex):
        data.columns = data.columns.get_level_values(0)
    return data
# Function to prepare the data for Chronos-Bolt
def prepare_data_chronos(data):
    """Convert one ticker's yfinance price frame into an AutoGluon TimeSeriesDataFrame.

    Args:
        data: DataFrame from ``yf.download`` for a single ticker, indexed by
            date and holding a "Close" column. The (field, ticker) MultiIndex
            column layout produced by newer yfinance releases is also accepted.

    Returns:
        A TimeSeriesDataFrame with one item, "stock". Its float32 "target"
        column holds the closing price, and rows are sorted by timestamp.
        "target" is TimeSeriesPredictor's default target column name, so no
        target needs to be passed to the constructor.
    """
    close = data["Close"]
    if isinstance(close, pd.DataFrame):
        # MultiIndex columns: data["Close"] is a one-column frame (one column per ticker).
        close = close.iloc[:, 0]

    timestamps = pd.DatetimeIndex(close.index)
    if timestamps.tz is not None:
        # astype("datetime64[ns]") rejects tz-aware values; drop the zone, keeping wall-clock dates.
        timestamps = timestamps.tz_localize(None)

    frame = pd.DataFrame({
        "item_id": "stock",
        "timestamp": timestamps.astype("datetime64[ns]"),
        "target": close.to_numpy(dtype="float32"),
    })

    ts_data = TimeSeriesDataFrame(
        frame,
        id_column="item_id",
        timestamp_column="timestamp",
    )
    # Make sure the series is in chronological order.
    return ts_data.sort_index()
# Functions that fetch the constituent tickers of each supported stock index
def get_tw0050_stocks(timeout=10):
    """Fetch the Taiwan 50 (0050) constituents as Yahoo ".TW" tickers.

    Args:
        timeout: Seconds to wait for the connection or a read before
            ``requests`` raises a Timeout (the original call could hang forever).

    Raises:
        requests.HTTPError: if the service answers with an error status.
    """
    response = requests.get('https://answerbook.david888.com/TW0050', timeout=timeout)
    response.raise_for_status()
    data = response.json()
    # The payload maps stock codes to details under the "TW0050" key; only the codes are used.
    return [f"{code}.TW" for code in data['TW0050']]
def get_sp500_stocks(limit=50, timeout=10):
    """Fetch S&P 500 tickers, keeping the first ``limit`` in the order the service returns them.

    Args:
        limit: Maximum number of tickers to return.
        timeout: Seconds to wait for the connection or a read before
            ``requests`` raises a Timeout.

    Raises:
        requests.HTTPError: if the service answers with an error status.
    """
    response = requests.get('https://answerbook.david888.com/SP500', timeout=timeout)
    response.raise_for_status()
    data = response.json()
    return list(data['SP500'])[:limit]
def get_nasdaq_stocks(limit=50, timeout=10):
    """Fetch NASDAQ-100 tickers, keeping the first ``limit`` in the order the service returns them.

    Args:
        limit: Maximum number of tickers to return.
        timeout: Seconds to wait for the connection or a read before
            ``requests`` raises a Timeout.

    Raises:
        requests.HTTPError: if the service answers with an error status.
    """
    # NOTE(review): plain HTTP to a raw IP address - no TLS and no stable hostname.
    response = requests.get('http://13.125.121.198:8090/stocks/NASDAQ100', timeout=timeout)
    response.raise_for_status()
    data = response.json()
    return list(data['stocks'])[:limit]
def get_tw0051_stocks(timeout=10):
    """Fetch the Taiwan Mid-Cap 100 (0051) constituents as Yahoo ".TW" tickers.

    Args:
        timeout: Seconds to wait for the connection or a read before
            ``requests`` raises a Timeout (the original call could hang forever).

    Raises:
        requests.HTTPError: if the service answers with an error status.
    """
    response = requests.get('https://answerbook.david888.com/TW0051', timeout=timeout)
    response.raise_for_status()
    data = response.json()
    # The payload maps stock codes to details under the "TW0051" key; only the codes are used.
    return [f"{code}.TW" for code in data['TW0051']]
def get_sox_stocks():
    """Return the hard-coded Philadelphia Semiconductor (SOX) ticker list.

    Unlike the other index helpers this makes no network call. The list is
    static and must be edited by hand when the index membership changes. A
    new list is returned on every call, so callers may mutate it freely.
    """
    sox_members = (
        "NVDA", "AVGO", "GFS", "CRUS", "ON",
        "ASML", "QCOM", "SWKS", "MPWR", "ADI",
        "TSM", "AMD", "TXN", "QRVO", "AMKR",
        "MU", "ARM", "NXPI", "TER", "ENTG",
        "LSCC", "COHR", "ONTO", "MTSI", "KLAC",
        "LRCX", "MRVL", "AMAT", "INTC", "MCHP",
    )
    return list(sox_members)
def get_dji_stocks(timeout=10):
    """Fetch the Dow Jones Industrial Average tickers.

    Args:
        timeout: Seconds to wait for the connection or a read before
            ``requests`` raises a Timeout (the original call could hang forever).

    Raises:
        requests.HTTPError: if the service answers with an error status.
    """
    # NOTE(review): plain HTTP to a raw IP address - no TLS and no stable hostname.
    response = requests.get('http://13.125.121.198:8090/stocks/DOWJONES', timeout=timeout)
    response.raise_for_status()
    data = response.json()
    return list(data['stocks'])
# Function to get top 10 potential stocks
def _collect_tickers(selected_indices):
    """Return the de-duplicated tickers for the index labels in ``selected_indices``.

    Labels are the CheckboxGroup choices shown in the UI. A failing
    constituent fetch is reported and skipped, so one unreachable source does
    not abort the whole request.
    """
    fetchers = (
        ("\u53f0\u706350", get_tw0050_stocks),
        ("\u53f0\u7063\u4e2d\u578b100", get_tw0051_stocks),
        ("S&P\u7cbe\u7c21\u724850", get_sp500_stocks),
        ("NASDAQ\u7cbe\u7c21\u724850", get_nasdaq_stocks),
        ("\u8cfd\u57ce\u534a\u5b57\u9ad4SOX", get_sox_stocks),
        ("\u9053\u74b0DJI", get_dji_stocks),
    )
    tickers = []
    for label, fetch in fetchers:
        if label not in selected_indices:
            continue
        try:
            tickers.extend(fetch())
        except Exception as e:
            print(f"Index {label} error: {e}")
    # The same ticker can come from several selected indices; keep the first
    # occurrence so each stock is forecast and listed only once.
    return list(dict.fromkeys(tickers))


def get_top_10_potential_stocks(period, selected_indices):
    """Forecast each selected stock with Chronos-Bolt and return the 10 with the highest upside.

    Args:
        period: Look-back window passed to ``get_stock_data`` (e.g. "3mo").
        selected_indices: Index labels chosen in the UI.

    Returns:
        Up to 10 ``(ticker, potential, last_close, forecast)`` tuples, sorted
        by ``potential`` in descending order. ``forecast`` is the mean forecast
        for the last step of a 10-step horizon. ``potential`` is the fraction
        ``(forecast - last_close) / last_close``. Tickers that fail (no data,
        download or model errors) are printed and skipped.
    """
    import os
    import tempfile

    prediction_length = 10
    stock_predictions = []
    for ticker in _collect_tickers(selected_indices):
        try:
            data = get_stock_data(ticker, period)
            if data.empty:
                continue
            close = data["Close"]
            if isinstance(close, pd.DataFrame):
                # Newer yfinance returns (field, ticker) MultiIndex columns even for one
                # ticker, making data["Close"] a one-column frame instead of a Series.
                close = close.iloc[:, 0]
            close = close.dropna()
            if close.empty:
                continue
            last_close = float(close.iloc[-1])

            ts_data = prepare_data_chronos(data)
            # A throw-away model directory stops one saved predictor per ticker
            # from piling up on disk on every request.
            with tempfile.TemporaryDirectory() as model_dir:
                # Trading days skip weekends and holidays, so the daily index is
                # irregular and its frequency cannot be inferred; "B" (business
                # day) makes the predictor regularise it and forecast trading days.
                predictor = TimeSeriesPredictor(
                    prediction_length=prediction_length,
                    freq="B",
                    path=os.path.join(model_dir, "predictor"),
                )
                predictor.fit(ts_data, hyperparameters={"Chronos": {"model_path": "amazon/chronos-bolt-base"}})
                predictions = predictor.predict(ts_data)

            # predict() returns a "mean" column plus one column per quantile; rank on
            # the mean forecast at the end of the horizon (a scalar, so sorting works).
            forecast = float(predictions["mean"].iloc[-1])
            potential = (forecast - last_close) / last_close
            stock_predictions.append((ticker, potential, last_close, forecast))
        except Exception as e:
            # Best effort: one bad ticker must not abort the whole ranking.
            print(f"Stock {ticker} error: {str(e)}")
            continue
    return sorted(stock_predictions, key=lambda row: row[1], reverse=True)[:10]
# Gradio interface function
def stock_prediction_app(period, selected_indices):
    """Gradio callback: rank the selected indices' stocks and return the result table.

    Columns, in order: ticker, potential in percent, current price, and
    forecast price.
    """
    top_10_stocks = get_top_10_potential_stocks(period, selected_indices)
    # get_top_10_potential_stocks reports potential as a fraction, but the column
    # header says percent (百分比), so scale it by 100. Values are rounded for display.
    rows = [
        (ticker, round(potential * 100, 2), round(price, 2), round(forecast, 2))
        for ticker, potential, price, forecast in top_10_stocks
    ]
    return pd.DataFrame(rows, columns=["\u80a1\u7968\u4ee3\u865f", "\u6f5b\u529b (\u767e\u5206\u6bd4)", "\u73fe\u50f9", "\u9810\u6e2c\u50f9\u683c"])
# Define Gradio interface
# Gradio UI: a look-back period dropdown plus index checkboxes; the callback
# returns the ranked table.
inputs = [
    gr.Dropdown(
        choices=["3mo", "6mo", "9mo", "1yr"],
        # Explicit default so the callback always receives a period string.
        value="3mo",
        label="\u6642\u9593\u7bc4\u570d",
    ),
    gr.CheckboxGroup(
        choices=["\u53f0\u706350", "\u53f0\u7063\u4e2d\u578b100", "S&P\u7cbe\u7c21\u724850", "NASDAQ\u7cbe\u7c21\u724850", "\u8cfd\u57ce\u534a\u5b57\u9ad4SOX", "\u9053\u74b0DJI"],
        label="\u6307\u6578\u9078\u64c7",
        value=["\u53f0\u706350", "\u53f0\u7063\u4e2d\u578b100"],
    ),
]
outputs = gr.Dataframe(label="\u6f5b\u529b\u80a1\u63a8\u85a6\u7d50\u679c")
demo = gr.Interface(
    fn=stock_prediction_app,
    inputs=inputs,
    outputs=outputs,
    title="\u53f0\u80a1\u7f8e\u80a1\u6f5b\u529b\u80a1\u63a8\u85a6\u7cfb\u7d71 - Chronos-Bolt\u6a21\u578b",
)

# Launch only when run as a script, so importing this module has no side effects.
if __name__ == "__main__":
    demo.launch()
|