tbdavid2019 commited on
Commit
fc08f05
·
1 Parent(s): 4fe64cc

照官方文件改

Browse files
Files changed (1) hide show
  1. app.py +31 -23
app.py CHANGED
@@ -12,26 +12,22 @@ def get_stock_data(ticker, period):
12
  # Function to prepare the data for Chronos-Bolt
13
 
14
  def prepare_data_chronos(data):
15
- # 確保索引重置並重命名欄位
16
- data = data.reset_index()
17
- data = data.rename(columns={"Date": "timestamp", "Close": "target"})
 
 
 
18
 
19
- # 只保留需要的欄位並設定正確的資料類型
20
- data = data[["timestamp", "target"]]
21
- data["item_id"] = "stock"
 
22
 
23
- # 設定正確的資料類型
24
- data["timestamp"] = pd.to_datetime(data["timestamp"])
25
- data["target"] = data["target"].astype('float32')
26
 
27
- # 建立 TimeSeriesDataFrame,只使用必要的參數
28
- ts_data = TimeSeriesDataFrame.from_data_frame(
29
- data,
30
- id_column="item_id",
31
- timestamp_column="timestamp"
32
- )
33
-
34
- return ts_data
35
 
36
 
37
  # Function to fetch stock indices (you already defined these)
@@ -85,25 +81,37 @@ def get_top_10_potential_stocks(period, selected_indices):
85
 
86
  stock_predictions = []
87
  prediction_length = 10
88
-
89
  for ticker in stock_list:
90
  try:
 
91
  data = get_stock_data(ticker, period)
92
  if data.empty:
93
  continue
94
-
 
95
  ts_data = prepare_data_chronos(data)
 
 
96
  predictor = TimeSeriesPredictor(prediction_length=prediction_length)
97
- predictor.fit(ts_data, hyperparameters={"Chronos": {"model_path": "amazon/chronos-bolt-base"}})
98
-
 
 
 
 
 
 
99
  predictions = predictor.predict(ts_data)
 
 
100
  potential = (predictions.iloc[-1] - data['Close'].iloc[-1]) / data['Close'].iloc[-1]
101
  stock_predictions.append((ticker, potential, data['Close'].iloc[-1], predictions.iloc[-1]))
102
-
103
  except Exception as e:
104
  print(f"Stock {ticker} error: {str(e)}")
105
  continue
106
-
107
  top_10_stocks = sorted(stock_predictions, key=lambda x: x[1], reverse=True)[:10]
108
  return top_10_stocks
109
 
 
12
  # Function to prepare the data for Chronos-Bolt
13
 
14
  def prepare_data_chronos(data):
15
+ # Convert to the correct format
16
+ df = data.reset_index()
17
+ df = df.rename(columns={
18
+ 'Date': 'timestamp',
19
+ 'Close': 'target',
20
+ })
21
 
22
+ # Ensure correct data types
23
+ df['timestamp'] = pd.to_datetime(df['timestamp'])
24
+ df['target'] = df['target'].astype('float32')
25
+ df['item_id'] = 'stock'
26
 
27
+ # Create TimeSeriesDataFrame directly
28
+ ts_df = TimeSeriesDataFrame(df)
 
29
 
30
+ return ts_df
 
 
 
 
 
 
 
31
 
32
 
33
  # Function to fetch stock indices (you already defined these)
 
81
 
82
  stock_predictions = []
83
  prediction_length = 10
84
+
85
  for ticker in stock_list:
86
  try:
87
+ # Get stock data
88
  data = get_stock_data(ticker, period)
89
  if data.empty:
90
  continue
91
+
92
+ # Prepare data
93
  ts_data = prepare_data_chronos(data)
94
+
95
+ # Create predictor and fit model
96
  predictor = TimeSeriesPredictor(prediction_length=prediction_length)
97
+ predictor.fit(
98
+ ts_data,
99
+ hyperparameters={
100
+ "Chronos": {"model_path": "autogluon/chronos-bolt-base"}
101
+ }
102
+ )
103
+
104
+ # Make predictions
105
  predictions = predictor.predict(ts_data)
106
+
107
+ # Calculate potential
108
  potential = (predictions.iloc[-1] - data['Close'].iloc[-1]) / data['Close'].iloc[-1]
109
  stock_predictions.append((ticker, potential, data['Close'].iloc[-1], predictions.iloc[-1]))
110
+
111
  except Exception as e:
112
  print(f"Stock {ticker} error: {str(e)}")
113
  continue
114
+
115
  top_10_stocks = sorted(stock_predictions, key=lambda x: x[1], reverse=True)[:10]
116
  return top_10_stocks
117