"""US stock symbol, price and index-history helpers.

The data source is selected by ``USE_YFINANCE_VERSION`` (imported from
``data_source_config``): yfinance when true, akshare otherwise.

Importing this module has side effects: it loads four index-constituent CSV
files from ``./model`` and schedules a background refresh of the index
history globals.
"""

# ================================
# Version switch - imported from the configuration file
# ================================
import asyncio
import functools
import logging
import os
import re
import threading
import time  # standard-library time module
from datetime import datetime, timedelta

import pandas as pd
import requests

from data_source_config import USE_YFINANCE_VERSION, API_TIMEOUT_SECONDS, MAX_RETRY_ATTEMPTS

# Import a different backend depending on the switch.
if USE_YFINANCE_VERSION:
    import yfinance as yf
    print("🔄 Using yfinance version (new)")
else:
    import akshare as ak
    print("🔄 Using akshare version (old)")

logging.basicConfig(level=logging.INFO)

# Directory containing this file.
base_dir = os.path.dirname(os.path.abspath(__file__))

# Absolute paths of the constituent CSV files.
nasdaq_100_path = os.path.join(base_dir, './model/nasdaq100.csv')
dow_jones_path = os.path.join(base_dir, './model/dji.csv')
sp500_path = os.path.join(base_dir, './model/sp500.csv')
nasdaq_composite_path = os.path.join(base_dir, './model/nasdaq_all.csv')

# Constituent data loaded from the CSV files.
nasdaq_100_stocks = pd.read_csv(nasdaq_100_path)
dow_jones_stocks = pd.read_csv(dow_jones_path)
sp500_stocks = pd.read_csv(sp500_path)
nasdaq_composite_stocks = pd.read_csv(nasdaq_composite_path)


def fetch_stock_us_spot_data_with_retries():
    """Return the symbol table from whichever data source the switch selects."""
    if USE_YFINANCE_VERSION:
        return fetch_stock_us_spot_data_yfinance()
    return fetch_stock_us_spot_data_akshare()


def fetch_stock_us_spot_data_akshare():
    """Fetch the US spot table through akshare, retrying until it succeeds.

    Returns an empty DataFrame (with a warning) when the yfinance version is
    active. Otherwise this blocks until ``ak.stock_us_spot_em`` succeeds,
    sleeping 10, 20, 60, 300 and then 600 seconds between attempts.
    """
    if USE_YFINANCE_VERSION:
        print("Warning: akshare function called while using yfinance version")
        return pd.DataFrame()

    # Retry wait times in seconds; the last value is reused once exhausted.
    retry_intervals = [10, 20, 60, 300, 600]
    retry_index = 0
    while True:
        try:
            return ak.stock_us_spot_em()
        except Exception as e:
            print(f"Error fetching data: {e}")
            wait_time = retry_intervals[retry_index]
            print(f"Retrying in {wait_time} seconds...")
            time.sleep(wait_time)
            # Advance the index without running past the end of the list.
            retry_index = min(retry_index + 1, len(retry_intervals) - 1)


def fetch_stock_us_spot_data_yfinance():
    """Build the symbol table from the local CSV files plus a fixed extra list.

    Returns a DataFrame with columns ``代码`` (ticker) and ``名称`` (a
    placeholder name of the form ``"<ticker> Inc."``). On any error the small
    table from ``create_fallback_symbols`` is returned instead.
    """
    try:
        all_symbols = set()

        # Collect tickers from each index CSV; the column is 'Symbol' or 'Code'.
        for df in (nasdaq_100_stocks, dow_jones_stocks, sp500_stocks, nasdaq_composite_stocks):
            column = next((c for c in ('Symbol', 'Code') if c in df.columns), None)
            if column is not None:
                all_symbols.update(df[column].dropna().astype(str))

        # Some common ETFs and popular stocks.
        additional_symbols = [
            # Major ETFs
            'SPY', 'QQQ', 'IWM', 'VTI', 'ARKK', 'TQQQ', 'SQQQ', 'SPXL',
            # Popular tech stocks
            'AAPL', 'MSFT', 'GOOGL', 'GOOG', 'AMZN', 'TSLA', 'META', 'NVDA',
            'NFLX', 'AMD', 'INTC', 'ORCL', 'CRM', 'ADBE', 'PYPL', 'UBER', 'LYFT',
            # China concept stocks
            'BABA', 'JD', 'PDD', 'NIO', 'XPEV', 'LI', 'DIDI', 'TME',
            # Other popular stocks
            'COST', 'WMT', 'JPM', 'BAC', 'XOM', 'CVX', 'PFE', 'JNJ', 'KO', 'PEP',
        ]
        all_symbols.update(additional_symbols)

        symbols_list = sorted(all_symbols)
        symbols_df = pd.DataFrame({
            '代码': symbols_list,
            '名称': [f'{symbol} Inc.' for symbol in symbols_list],  # simple name mapping
        })
        print(f"Created symbols dataframe with {len(symbols_df)} symbols using yfinance version")
        return symbols_df
    except Exception as e:
        print(f"Error creating symbols dataframe: {e}")
        # Basic fallback data (same table the original duplicated inline).
        return create_fallback_symbols()


async def fetch_stock_us_spot_data_with_retries_async():
    """Async variant of the symbol fetch; honours the version switch."""
    if not USE_YFINANCE_VERSION:
        return await fetch_stock_us_spot_data_akshare_async()
    try:
        return await asyncio.to_thread(fetch_stock_us_spot_data_yfinance)
    except Exception as e:
        print(f"Error in async yfinance fetch: {e}")
        return pd.DataFrame()


async def fetch_stock_us_spot_data_akshare_async():
    """Fetch the akshare spot table off-thread with a bounded number of retries.

    Each attempt is limited to 30 seconds. After the initial attempt there
    are at most two retries, preceded by 10 s and 20 s sleeps. Returns an
    empty DataFrame if every attempt fails or the yfinance version is active.
    """
    if USE_YFINANCE_VERSION:
        print("Warning: akshare async function called while using yfinance version")
        return pd.DataFrame()

    retry_intervals = [10, 20]  # fewer retries than the sync variant
    retry_index = 0
    max_retries = 2  # retry at most twice
    for attempt in range(max_retries + 1):
        try:
            # 30 second timeout per attempt.
            spot_data = await asyncio.wait_for(
                asyncio.to_thread(ak.stock_us_spot_em),
                timeout=30.0,
            )
            return spot_data
        except asyncio.TimeoutError:
            print(f"Timeout error fetching data (attempt {attempt + 1}/{max_retries + 1})")
        except Exception as e:
            print(f"Error fetching data (attempt {attempt + 1}/{max_retries + 1}): {e}")

        if attempt < max_retries:
            wait_time = retry_intervals[min(retry_index, len(retry_intervals) - 1)]
            print(f"Retrying in {wait_time} seconds...")
            await asyncio.sleep(wait_time)
            retry_index += 1

    # Every attempt failed: return empty data.
    print("All retries failed, returning empty data")
    return pd.DataFrame()


# Symbol table; stays None until fetch_symbols() has run.
symbols = None


def create_fallback_symbols():
    """Create a small fallback symbol table (used for tests and error paths)."""
    fallback_symbols = [
        'AAPL', 'MSFT', 'GOOGL', 'AMZN', 'TSLA', 'META', 'NVDA', 'NFLX',
        'SPY', 'QQQ', 'IWM', 'VTI',
    ]
    return pd.DataFrame({
        '代码': fallback_symbols,
        '名称': [f'{symbol} Inc.' for symbol in fallback_symbols],
    })


async def fetch_symbols():
    """Populate the module-level ``symbols`` table.

    On failure or an empty result ``symbols`` is set to an empty DataFrame,
    so it is never left as None once this coroutine has finished.
    """
    global symbols
    try:
        print("Starting symbols initialization...")
        # Fetch the data asynchronously.
        symbols = await fetch_stock_us_spot_data_with_retries_async()
        if symbols is not None and not symbols.empty:
            print(f"Symbols initialized successfully: {len(symbols)} symbols loaded")
        else:
            print("Symbols initialization failed, using empty dataset")
            symbols = pd.DataFrame()
    except Exception as e:
        print(f"Error in fetch_symbols: {e}")
        symbols = pd.DataFrame()
    finally:
        print("Symbols initialization completed")


# Index history globals; each stays None until the first successful update.
index_us_stock_index_INX = None
index_us_stock_index_DJI = None
index_us_stock_index_IXIC = None
index_us_stock_index_NDX = None


def _schedule_indices_update(delay_seconds):
    """Run ``update_stock_indices`` after ``delay_seconds`` on a daemon timer."""
    timer = threading.Timer(delay_seconds, update_stock_indices)
    # Daemon, so the self-rescheduling refresh never keeps the interpreter
    # alive after the main thread has finished.
    timer.daemon = True
    timer.start()


def update_stock_indices():
    """Refresh the four index history globals and re-arm the 12-hour timer.

    With yfinance, roughly the last 8 weeks are fetched for ^GSPC, ^DJI,
    ^IXIC and ^NDX and reshaped into a frame with a ``date`` column plus the
    Chinese column names that ``column_mapping`` later translates. With
    akshare, ``ak.index_us_stock_sina`` is called for each index. An index
    whose fetch fails keeps its previous value.
    """
    global index_us_stock_index_INX, index_us_stock_index_DJI, index_us_stock_index_IXIC, index_us_stock_index_NDX
    try:
        print("Starting stock indices update...")
        if USE_YFINANCE_VERSION:
            print("Updating indices using yfinance...")

            # Date range to fetch.
            end_date = datetime.now()
            start_date = end_date - timedelta(weeks=8)

            # yfinance ticker -> suffix of the global that stores it.
            indices = {
                '^GSPC': 'INX',   # S&P 500
                '^DJI': 'DJI',    # Dow Jones
                '^IXIC': 'IXIC',  # NASDAQ Composite
                '^NDX': 'NDX',    # NASDAQ 100
            }

            for yf_symbol, var_name in indices.items():
                try:
                    hist_data = yf.Ticker(yf_symbol).history(start=start_date, end=end_date)
                    if hist_data.empty:
                        print(f"No data received for {yf_symbol}")
                        continue

                    # Reshape into the column layout the rest of the module expects.
                    formatted_data = pd.DataFrame({
                        'date': hist_data.index.strftime('%Y-%m-%d'),
                        '开盘': hist_data['Open'].values,
                        '收盘': hist_data['Close'].values,
                        '最高': hist_data['High'].values,
                        '最低': hist_data['Low'].values,
                        '成交量': hist_data['Volume'].values,
                        '成交额': (hist_data['Close'] * hist_data['Volume']).values,
                    })

                    if var_name == 'INX':
                        index_us_stock_index_INX = formatted_data
                    elif var_name == 'DJI':
                        index_us_stock_index_DJI = formatted_data
                    elif var_name == 'IXIC':
                        index_us_stock_index_IXIC = formatted_data
                    elif var_name == 'NDX':
                        index_us_stock_index_NDX = formatted_data

                    print(f"Successfully updated {var_name}: {len(formatted_data)} records")
                except Exception as e:
                    print(f"Error fetching {yf_symbol}: {e}")
        else:
            print("Updating indices using akshare...")
            index_us_stock_index_INX = ak.index_us_stock_sina(symbol=".INX")
            index_us_stock_index_DJI = ak.index_us_stock_sina(symbol=".DJI")
            index_us_stock_index_IXIC = ak.index_us_stock_sina(symbol=".IXIC")
            index_us_stock_index_NDX = ak.index_us_stock_sina(symbol=".NDX")

        print("Stock indices updated successfully")
    except Exception as e:
        print(f"Error updating stock indices: {e}")

    # Re-arm the timer: refresh again in 12 hours.
    _schedule_indices_update(12 * 60 * 60)


def start_indices_update():
    """Start the index refresh after a delay so application startup is not blocked."""
    _schedule_indices_update(5)  # first update 5 seconds from now


# Kick off the delayed index refresh at import time.
start_indices_update()

# Column-name translation (akshare Chinese headers -> English).
column_mapping = {
    '日期': 'date',
    '开盘': 'open',
    '收盘': 'close',
    '最高': 'high',
    '最低': 'low',
    '成交量': 'volume',
    '成交额': 'amount',
    '振幅': 'amplitude',
    '涨跌幅': 'price_change_percentage',
    '涨跌额': 'price_change_amount',
    '换手率': 'turnover_rate',
}

# Standard column order returned by the history functions.
standard_columns = ['date', 'open', 'close', 'high', 'low', 'volume', 'amount']


def find_stock_entry(stock_code):
    """Resolve a bare ticker to the code stored in the ``symbols`` table.

    An exact match on the ``代码`` column wins; otherwise a code ending in
    ``"." + stock_code`` is accepted.
    NOTE(review): this assumes prefixed akshare codes look like "105.AAPL" —
    confirm against the akshare spot table.

    Returns "" when ``stock_code`` is empty, when the symbol table is not
    loaded yet (None, empty, or lacking a ``代码`` column), or when nothing
    matches.
    """
    if not stock_code:
        return ""
    if symbols is None or symbols.empty or '代码' not in symbols.columns:
        return ""

    codes = symbols['代码'].astype(str)
    exact = codes[codes == stock_code]
    if not exact.empty:
        return exact.values[0]

    # Require the '.' delimiter so that 'A' cannot match e.g. '105.TSLA'.
    suffixed = codes[codes.str.endswith('.' + stock_code)]
    if not suffixed.empty:
        return suffixed.values[0]
    return ""


# Example:
#   find_stock_entry('AAPL')  # -> matching code, or "" when not found


def reduce_columns(df, columns_to_keep):
    """Return ``df`` restricted to ``columns_to_keep`` (in that order)."""
    return df[columns_to_keep]


# Price cache: symbol -> (price, time the price was fetched).
_price_cache = {}


def _extract_close_series(frame):
    """Return the non-null closing prices of ``frame`` as a Series.

    Accepts the yfinance ``Close`` column (including the MultiIndex column
    layout, where ``frame['Close']`` is itself a DataFrame) and the akshare
    ``收盘`` column. Returns an empty Series when ``frame`` is None or empty,
    or when neither column is present.
    """
    if frame is None or frame.empty:
        return pd.Series(dtype=float)
    for column in ('Close', '收盘'):
        if column in frame.columns:
            closes = frame[column]
            if isinstance(closes, pd.DataFrame):
                # MultiIndex columns: one sub-column per ticker; only one was requested.
                closes = closes.iloc[:, 0]
            return closes.dropna()
    return pd.Series(dtype=float)


def get_last_minute_stock_price(symbol: str, max_retries=3) -> float:
    """Return the latest price for ``symbol``, cached for 30 minutes.

    Args:
        symbol: Ticker to look up. A falsy value or the sentinel
            "NONE_SYMBOL_FOUND" returns -1.0 immediately.
        max_retries: Number of fetch attempts, with a 1 second pause between
            attempts.

    Returns:
        The most recent non-null closing price, or -1.0 when every attempt
        fails or yields no data.
    """
    if not symbol or symbol == "NONE_SYMBOL_FOUND":
        return -1.0

    current_time = datetime.now()

    # Serve from the cache while the entry is younger than 30 minutes.
    cached = _price_cache.get(symbol)
    if cached is not None:
        cached_price, cached_time = cached
        if current_time - cached_time < timedelta(minutes=30):
            return cached_price

    for attempt in range(max_retries):
        try:
            if USE_YFINANCE_VERSION:
                stock_data = yf.download(
                    symbol,
                    period='1d',
                    interval='5m',
                    progress=False,  # disable the progress bar
                    timeout=10,      # request timeout
                )
            else:
                # Query a rolling two-week window so the last row really is
                # the most recent daily close (the window used to be a
                # hard-coded January 2024 range).
                window_start = current_time - timedelta(days=14)
                stock_data = ak.stock_us_hist(
                    symbol=symbol,
                    period="daily",
                    start_date=window_start.strftime("%Y%m%d"),
                    end_date=current_time.strftime("%Y%m%d"),
                )

            closes = _extract_close_series(stock_data)
            if closes.empty:
                print(f"Warning: Empty data received for {symbol}, attempt {attempt + 1}/{max_retries}")
                if attempt == max_retries - 1:
                    return -1.0
                time.sleep(1)  # wait 1 second before retrying
                continue

            latest_price = float(closes.iloc[-1])
            _price_cache[symbol] = (latest_price, current_time)
            return latest_price
        except Exception as e:
            print(f"Error fetching price for {symbol}, attempt {attempt + 1}/{max_retries}: {str(e)}")
            if attempt == max_retries - 1:
                return -1.0
            time.sleep(1)  # wait 1 second before retrying

    return -1.0


def _fetch_history_yfinance(symbol, start, end):
    """Fetch daily history via yfinance in the akshare column layout.

    ``start`` and ``end`` are ``YYYY-MM-DD`` strings. Returns None when
    yfinance yields no rows.
    """
    raw = yf.Ticker(symbol).history(start=start, end=end)
    if raw.empty:
        return None
    raw = raw.reset_index()
    return pd.DataFrame({
        'date': raw['Date'].dt.strftime('%Y-%m-%d'),
        '开盘': raw['Open'],
        '收盘': raw['Close'],
        '最高': raw['High'],
        '最低': raw['Low'],
        '成交量': raw['Volume'],
        '成交额': raw['Close'] * raw['Volume'],
        '振幅': 0,    # not provided directly by yfinance; set to 0
        '涨跌幅': 0,  # could be computed; simplified to 0 here
        '涨跌额': 0,  # could be computed; simplified to 0 here
        '换手率': 0,  # not provided directly by yfinance; set to 0
    })


def get_stock_history(symbol, news_date, retries=10):
    """Return the last 60 days of daily history for a stock.

    Args:
        symbol: Ticker. When it contains no digit it is resolved through
            ``find_stock_entry``; an unresolved, empty or None symbol skips
            the fetch entirely.
        news_date: Only used in log messages; the window always ends today.
        retries: Maximum number of retries after a timeout/connection error.

    Returns:
        A DataFrame with exactly ``standard_columns``. When no data could be
        obtained, a zero-filled frame with one row per calendar day of the
        window is returned.
    """
    # Wait times (seconds) between retries; the last value is reused.
    retry_intervals = [10, 20, 60, 300, 600]

    symbol = symbol or ""
    # A symbol without digits has no exchange prefix: look up the full code.
    if not any(char.isdigit() for char in symbol):
        symbol = find_stock_entry(symbol) or ""

    current_date = datetime.now()
    window_start = current_date - timedelta(days=60)
    start_date = window_start.strftime("%Y%m%d")
    end_date = current_date.strftime("%Y%m%d")

    stock_hist_df = None
    retry_count = 0
    while symbol and retry_count <= retries:
        try:
            if USE_YFINANCE_VERSION:
                # yfinance expects YYYY-MM-DD dates.
                stock_hist_df = _fetch_history_yfinance(
                    symbol,
                    window_start.strftime("%Y-%m-%d"),
                    current_date.strftime("%Y-%m-%d"),
                )
            else:
                stock_hist_df = ak.stock_us_hist(
                    symbol=symbol,
                    period="daily",
                    start_date=start_date,
                    end_date=end_date,
                    adjust="",
                )
            if stock_hist_df is None or stock_hist_df.empty:
                stock_hist_df = None
            break
        except (requests.exceptions.Timeout, ConnectionError) as e:
            print(f"Request timed out: {e}. Retrying...")
            retry_count += 1
            if retry_count > retries:
                break
            # Back off before the next attempt. This sleep used to sit after
            # the try block, where `continue` made it unreachable.
            wait_time = retry_intervals[min(retry_count - 1, len(retry_intervals) - 1)]
            print(f"Waiting for {wait_time} seconds before retrying...")
            time.sleep(wait_time)
        except Exception as e:
            # Probably just no data: give up. `Exception` rather than
            # `BaseException`, so Ctrl+C and SystemExit still propagate.
            print(f"Error {e} scraping data for {symbol} on {news_date}. Break...")
            break

    if stock_hist_df is None or stock_hist_df.empty:
        # Nothing fetched: build a zero-filled frame covering the window.
        date_range = pd.date_range(start=start_date, end=end_date)
        stock_hist_df = pd.DataFrame({
            'date': date_range,
            '开盘': 0,
            '收盘': 0,
            '最高': 0,
            '最低': 0,
            '成交量': 0,
            '成交额': 0,
            '振幅': 0,
            '涨跌幅': 0,
            '涨跌额': 0,
            '换手率': 0,
        })

    # Normalise the column names and keep only the standard columns.
    stock_hist_df = stock_hist_df.rename(columns=column_mapping)
    stock_hist_df = stock_hist_df.reindex(columns=standard_columns)
    return reduce_columns(stock_hist_df, standard_columns)


# Example:
#   get_stock_history('AAPL', '20240214')
#   get_stock_history('ATMU', '20231218')


def get_stock_index_history(symbol, news_date, force_index=0):
    """Return the last 8 weeks of history for the index a stock belongs to.

    Membership is checked in the order NASDAQ-100, Dow Jones, S&P 500;
    anything else (including None or "") uses the NASDAQ Composite.
    ``force_index`` 1-4 selects NDX / DJI / INX / IXIC, but an earlier
    membership match still takes precedence. ``news_date`` is unused.

    Returns:
        A DataFrame with ``standard_columns``. It is empty when the index
        data has not been loaded yet.
    """
    if symbol in nasdaq_100_stocks['Symbol'].values or force_index == 1:
        index_data = index_us_stock_index_NDX
    elif symbol in dow_jones_stocks['Symbol'].values or force_index == 2:
        index_data = index_us_stock_index_DJI
    elif symbol in sp500_stocks['Symbol'].values or force_index == 3:
        index_data = index_us_stock_index_INX
    else:
        # NASDAQ Composite members and every unmatched symbol.
        index_data = index_us_stock_index_IXIC

    # The first background update runs 5 seconds after import (and may
    # fail), so the global can still be None here.
    if index_data is None or index_data.empty:
        return pd.DataFrame(columns=standard_columns)

    current_date = datetime.now()
    start_date = (current_date - timedelta(weeks=8)).strftime("%Y-%m-%d")
    end_date = current_date.strftime("%Y-%m-%d")

    # Parse the dates into a separate Series so the shared global frame,
    # which the timer thread replaces, is never modified in place.
    dates = pd.to_datetime(index_data['date'])
    in_window = (dates >= start_date) & (dates <= end_date)
    index_hist_df = index_data.loc[in_window].copy()
    index_hist_df['date'] = dates[in_window]

    # Normalise the column names and keep only the standard columns.
    index_hist_df = index_hist_df.rename(columns=column_mapping)
    index_hist_df = index_hist_df.reindex(columns=standard_columns)
    return reduce_columns(index_hist_df, standard_columns)


# Example:
#   get_stock_index_history('AAPL', '20240214')


@functools.lru_cache(maxsize=1)
def _symbol_lookup():
    """Build the ticker set and lower-cased name -> ticker map from the CSVs.

    Names and tickers are paired row by row within each CSV. The first
    ticker seen for a name wins. Cached because the CSV frames are loaded
    once at import time.
    """
    known_symbols = set()
    name_to_symbol = {}
    sources = (
        (nasdaq_100_stocks, 'Name'),
        (nasdaq_composite_stocks, 'Name'),
        (sp500_stocks, 'Security'),
        (dow_jones_stocks, 'Company'),
    )
    for frame, name_column in sources:
        known_symbols.update(frame['Symbol'].dropna().astype(str))
        pairs = frame[['Symbol', name_column]].dropna().astype(str)
        for ticker, name in pairs.itertuples(index=False, name=None):
            name_to_symbol.setdefault(name.lower(), ticker)
    return known_symbols, name_to_symbol


def find_stock_codes_or_names(entities):
    """Map named entities to stock tickers.

    :param entities: NER results as ``[(entity_text, entity_type), ...]``
    :return: sorted list of matching tickers, or ``['NONE_SYMBOL_FOUND']``
        when nothing matches

    An entity matches when its upper-cased text is a known ticker, or when
    it occurs as a whole word (or whole-word phrase) inside a company name,
    case-insensitively.
    """
    known_symbols, name_to_symbol = _symbol_lookup()
    stock_codes = set()

    for entity, entity_type in entities:
        # An empty entity would compile to r'\b\b', which matches every name.
        if not entity or not entity.strip():
            continue
        entity_lower = entity.lower()
        entity_upper = entity.upper()

        # Ticker match.
        if entity_upper in known_symbols:
            stock_codes.add(entity_upper)

        # Company-name match: whole words only, not arbitrary substrings.
        pattern = re.compile(rf'\b{re.escape(entity_lower)}\b')
        for name, ticker in name_to_symbol.items():
            if pattern.search(name):
                stock_codes.add(ticker.upper())

    if not stock_codes:
        return ['NONE_SYMBOL_FOUND']
    return sorted(stock_codes)


def process_history(stock_history, target_date, history_days=30, following_days=3):
    """Split a history frame into rows up to and rows after ``target_date``.

    Args:
        stock_history: Frame with a ``date`` column. It is not modified.
        target_date: Anything ``pd.to_datetime`` accepts.
        history_days: Rows to return up to and including the last row dated
            on or before ``target_date``.
        following_days: Rows to return after that row.

    Returns:
        ``(previous_rows, following_rows)`` with ``date`` dropped and at most
        the first six columns kept. Each part is padded with rows of -1 when
        too few rows exist. Frames from ``create_empty_data`` are returned
        when the input is empty, lacks ``date``, or has no row on or before
        ``target_date``.
    """
    if stock_history.empty or 'date' not in stock_history.columns:
        return create_empty_data(history_days), create_empty_data(following_days)

    # Work on a copy so the caller's frame is left untouched.
    stock_history = stock_history.copy()
    stock_history['date'] = pd.to_datetime(stock_history['date'])
    target_date = pd.to_datetime(target_date)

    # Ascending by date.
    stock_history = stock_history.sort_values('date')

    on_or_before = stock_history['date'] <= target_date
    if not on_or_before.any():
        return create_empty_data(history_days), create_empty_data(following_days)

    # After sorting, rows on or before the target form a prefix, so the last
    # such row sits at (count - 1). This also works with duplicate index labels.
    target_pos = int(on_or_before.sum()) - 1

    # History window (inclusive of the target row) and the rows that follow.
    start_pos = max(0, target_pos - history_days + 1)
    previous_rows = stock_history.iloc[start_pos:target_pos + 1]
    following_rows = stock_history.iloc[target_pos + 1:target_pos + following_days + 1]

    previous_rows = previous_rows.drop(columns=['date'])
    following_rows = following_rows.drop(columns=['date'])

    # Pad when there are fewer rows than requested.
    previous_rows = handle_insufficient_data(previous_rows, history_days)
    following_rows = handle_insufficient_data(following_rows, following_days)

    return previous_rows.iloc[:, :6], following_rows.iloc[:, :6]


def create_empty_data(days):
    """Return a ``days``-row placeholder frame filled with -1."""
    placeholder = [-1] * days
    return pd.DataFrame({
        '开盘': list(placeholder),
        '收盘': list(placeholder),
        '最高': list(placeholder),
        '最低': list(placeholder),
        '成交量': list(placeholder),
        '成交额': list(placeholder),
    })


def handle_insufficient_data(data, required_days):
    """Left-pad ``data`` with rows of -1 until it has ``required_days`` rows.

    The padding reuses ``data``'s own columns. Padding with the Chinese
    headers from ``create_empty_data`` made ``concat`` produce a union of
    columns whenever ``data`` used other names, which pushed the real values
    out of the first six columns.
    """
    missing_rows = required_days - len(data)
    if missing_rows <= 0:
        return data

    padding = pd.DataFrame(-1, index=range(missing_rows), columns=data.columns)
    if data.empty:
        return padding
    return pd.concat([padding, data]).reset_index(drop=True)


if __name__ == "__main__":
    # Smoke tests.
    result = find_stock_entry('AAPL')
    print(f"find_stock_entry: {result}")

    result = get_stock_history('AAPL', '20240214')
    print(f"get_stock_history: {result}")

    result = get_stock_index_history('AAPL', '20240214')
    print(f"get_stock_index_history: {result}")

    result = find_stock_codes_or_names([('苹果', 'ORG'), ('苹果公司', 'ORG')])
    print(f"find_stock_codes_or_names: {result}")

    result = process_history(get_stock_history('AAPL', '20240214'), '20240214')
    print(f"process_history: {result}")

    result = process_history(get_stock_index_history('AAPL', '20240214'), '20240214')
    print(f"process_history: {result}")