recrysss committed on
Commit
f4f80fa
·
verified ·
1 Parent(s): 4cf40c0

Upload 12 files

Browse files
Dockerfile ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
FROM python:3.11-slim

WORKDIR /app

# Copy the dependency manifest first so the pip-install layer is cached
# and only rebuilt when requirements.txt itself changes.
COPY requirements.txt .

RUN pip install --no-cache-dir -r requirements.txt

# Application code is copied last: source edits no longer invalidate the
# dependency layer above (the original copied ./app before requirements,
# forcing a full reinstall on every code change).
COPY ./app /app/app

# Environment variables (set these in Hugging Face Spaces settings)
# ENV GEMINI_API_KEYS=your_key_1,your_key_2,your_key_3

CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "7860"]
app/__init__.py ADDED
File without changes
app/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (116 Bytes). View file
 
app/__pycache__/gemini.cpython-39.pyc ADDED
Binary file (3.02 kB). View file
 
app/__pycache__/main.cpython-39.pyc ADDED
Binary file (7.65 kB). View file
 
app/__pycache__/models.cpython-39.pyc ADDED
Binary file (2.32 kB). View file
 
app/__pycache__/utils.cpython-39.pyc ADDED
Binary file (2.42 kB). View file
 
app/gemini.py ADDED
@@ -0,0 +1,278 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import requests
2
+ import json
3
+ import os
4
+ import asyncio
5
+ from app.models import ChatCompletionRequest, Message # 相对导入
6
+ from dataclasses import dataclass
7
+ from typing import Optional, Dict, Any, List
8
+ import httpx
9
+ import logging
10
+
11
+ logger = logging.getLogger('my_logger')
12
+
13
+
14
@dataclass
class GeneratedText:
    """Container for one piece of generated model output.

    NOTE(review): not referenced elsewhere in this module — confirm it is
    used by a caller before relying on its shape.
    """
    # The generated text itself.
    text: str
    # Why generation stopped (e.g. "STOP"); None when not reported.
    finish_reason: Optional[str] = None
18
+
19
+
20
class ResponseWrapper:
    """Read-only view over a raw Gemini ``generateContent`` JSON payload.

    All fields (text, thoughts, finish reason, token usage) are extracted
    once at construction time and exposed as properties. Malformed or
    partial payloads never raise: string fields degrade to ``""`` and
    counters to ``None``.
    """

    def __init__(self, data: Dict[Any, Any]):
        self._data = data
        self._text = self._extract_text()
        self._finish_reason = self._extract_finish_reason()
        self._prompt_token_count = self._extract_prompt_token_count()
        self._candidates_token_count = self._extract_candidates_token_count()
        self._total_token_count = self._extract_total_token_count()
        self._thoughts = self._extract_thoughts()
        # Pretty-printed copy of the raw payload, handy for logging/debugging.
        self._json_dumps = json.dumps(self._data, indent=4, ensure_ascii=False)

    def _extract_thoughts(self) -> Optional[str]:
        """Concatenate the text of every 'thought' part; "" when absent/malformed."""
        try:
            # BUGFIX: the original returned only the FIRST thought part,
            # silently dropping the rest of a multi-part response.
            return "".join(
                part['text']
                for part in self._data['candidates'][0]['content']['parts']
                if 'thought' in part
            )
        except (KeyError, IndexError, TypeError):
            # TypeError added: candidates/content may be None or a non-dict.
            return ""

    def _extract_text(self) -> str:
        """Concatenate the text of every non-'thought' part; "" when absent/malformed."""
        try:
            # BUGFIX: join all answer parts instead of returning just the first.
            return "".join(
                part['text']
                for part in self._data['candidates'][0]['content']['parts']
                if 'thought' not in part
            )
        except (KeyError, IndexError, TypeError):
            return ""

    def _extract_finish_reason(self) -> Optional[str]:
        """Return candidates[0].finishReason, or None when missing."""
        try:
            return self._data['candidates'][0].get('finishReason')
        except (KeyError, IndexError, TypeError):
            return None

    def _usage_value(self, field: str) -> Optional[int]:
        """Shared lookup for usageMetadata counters; None when absent."""
        try:
            return self._data['usageMetadata'].get(field)
        except (KeyError, TypeError):
            return None

    def _extract_prompt_token_count(self) -> Optional[int]:
        return self._usage_value('promptTokenCount')

    def _extract_candidates_token_count(self) -> Optional[int]:
        return self._usage_value('candidatesTokenCount')

    def _extract_total_token_count(self) -> Optional[int]:
        return self._usage_value('totalTokenCount')

    @property
    def text(self) -> str:
        return self._text

    @property
    def finish_reason(self) -> Optional[str]:
        return self._finish_reason

    @property
    def prompt_token_count(self) -> Optional[int]:
        return self._prompt_token_count

    @property
    def candidates_token_count(self) -> Optional[int]:
        return self._candidates_token_count

    @property
    def total_token_count(self) -> Optional[int]:
        return self._total_token_count

    @property
    def thoughts(self) -> Optional[str]:
        return self._thoughts

    @property
    def json_dumps(self) -> str:
        return self._json_dumps
100
+
101
+
102
class GeminiClient:
    """Thin client for the Google Generative Language (Gemini) REST API.

    Each instance is bound to a single API key. Streaming requests use
    httpx (async SSE); the blocking completion path uses requests.
    """

    # Populated externally at application startup with model name strings.
    AVAILABLE_MODELS = []
    # Extra model names appended from the comma-separated EXTRA_MODELS env var.
    EXTRA_MODELS = os.environ.get("EXTRA_MODELS", "").split(",")

    def __init__(self, api_key: str):
        self.api_key = api_key

    async def stream_chat(self, request: ChatCompletionRequest, contents, safety_settings, system_instruction):
        """Yield text chunks from the ``streamGenerateContent`` SSE endpoint.

        Args:
            request: incoming OpenAI-style request; model, temperature and
                max_tokens are forwarded.
            contents: Gemini-format history (see convert_messages).
            safety_settings: category/threshold dicts forwarded verbatim.
            system_instruction: optional {"parts": [...]} dict, sent when truthy.

        Raises:
            ValueError: on a non-STOP finishReason or a HIGH safety rating.
        """
        logger.info("流式开始 →")
        # Models with "think" in the name are only served by the v1alpha API.
        api_version = "v1alpha" if "think" in request.model else "v1beta"
        url = f"https://generativelanguage.googleapis.com/{api_version}/models/{request.model}:streamGenerateContent?key={self.api_key}&alt=sse"
        headers = {
            "Content-Type": "application/json",
        }
        data = {
            "contents": contents,
            "generationConfig": {
                "temperature": request.temperature,
                "maxOutputTokens": request.max_tokens,
            },
            "safetySettings": safety_settings,
        }
        if system_instruction:
            data["system_instruction"] = system_instruction

        async with httpx.AsyncClient() as client:
            async with client.stream("POST", url, headers=headers, json=data, timeout=600) as response:
                # Accumulates partial SSE payload bytes until they parse as JSON.
                buffer = b""
                try:
                    async for line in response.aiter_lines():
                        if not line.strip():
                            continue
                        if line.startswith("data: "):
                            line = line[len("data: "):]
                        buffer += line.encode('utf-8')
                        try:
                            data = json.loads(buffer.decode('utf-8'))
                            buffer = b""
                            if 'candidates' in data and data['candidates']:
                                candidate = data['candidates'][0]
                                if 'content' in candidate:
                                    content = candidate['content']
                                    if 'parts' in content and content['parts']:
                                        parts = content['parts']
                                        text = ""
                                        for part in parts:
                                            if 'text' in part:
                                                text += part['text']
                                        if text:
                                            yield text

                                # Any finish reason other than STOP means the
                                # reply was cut short (policy, length, etc.).
                                if candidate.get("finishReason") and candidate.get("finishReason") != "STOP":
                                    # logger.warning(f"模型的响应因违反内容政策而被标记: {candidate.get('finishReason')}")
                                    raise ValueError(f"模型的响应被截断: {candidate.get('finishReason')}")

                                if 'safetyRatings' in candidate:
                                    for rating in candidate['safetyRatings']:
                                        if rating['probability'] == 'HIGH':
                                            # logger.warning(f"模型的响应因高概率被标记为 {rating['category']}")
                                            raise ValueError(f"模型的响应被截断: {rating['category']}")
                        except json.JSONDecodeError:
                            # Incomplete JSON fragment: keep accumulating.
                            # logger.debug(f"JSON解析错误, 当前缓冲区内容: {buffer}")
                            continue
                        except Exception as e:
                            # logger.error(f"流式处理期间发生错误: {e}")
                            raise e
                except Exception as e:
                    # logger.error(f"流式处理错误: {e}")
                    raise e
                finally:
                    logger.info("流式结束 ←")


    def complete_chat(self, request: ChatCompletionRequest, contents, safety_settings, system_instruction):
        """Blocking (non-stream) completion via ``generateContent``.

        Returns a ResponseWrapper over the parsed JSON body.
        Raises requests.exceptions.HTTPError on non-2xx status.
        """
        api_version = "v1alpha" if "think" in request.model else "v1beta"
        url = f"https://generativelanguage.googleapis.com/{api_version}/models/{request.model}:generateContent?key={self.api_key}"
        headers = {
            "Content-Type": "application/json",
        }
        data = {
            "contents": contents,
            "generationConfig": {
                "temperature": request.temperature,
                "maxOutputTokens": request.max_tokens,
            },
            "safetySettings": safety_settings,
        }
        if system_instruction:
            data["system_instruction"] = system_instruction
        response = requests.post(url, headers=headers, json=data)
        response.raise_for_status()
        return ResponseWrapper(response.json())

    def convert_messages(self, messages, use_system_prompt=False):
        """Convert OpenAI-style messages to Gemini "contents" format.

        Leading 'system' messages are folded into a system instruction when
        use_system_prompt is True; consecutive messages with the same mapped
        role are merged into one entry. List-valued content supports 'text'
        and data-URI 'image_url' items.

        Returns:
            Either a list of error strings (when any message was invalid), or
            a tuple ``(gemini_history, system_instruction_dict)``.
            NOTE(review): the two return shapes differ — callers that unpack
            two values will misbehave on the error path; confirm callers.
        NOTE(review): call sites pass the class itself as ``self``
        (``GeminiClient.convert_messages(GeminiClient, ...)``); ``self`` is
        unused, so this works, but a @staticmethod would be cleaner.
        """
        gemini_history = []
        errors = []
        system_instruction_text = ""
        is_system_phase = use_system_prompt
        for i, message in enumerate(messages):
            role = message.role
            content = message.content

            if isinstance(content, str):
                if is_system_phase and role == 'system':
                    # Still in the leading system block: accumulate the text.
                    if system_instruction_text:
                        system_instruction_text += "\n" + content
                    else:
                        system_instruction_text = content
                else:
                    # First non-system message ends the system phase for good.
                    is_system_phase = False

                    if role in ['user', 'system']:
                        role_to_use = 'user'
                    elif role == 'assistant':
                        role_to_use = 'model'
                    else:
                        errors.append(f"Invalid role: {role}")
                        continue

                    # Merge with the previous turn when the role repeats.
                    if gemini_history and gemini_history[-1]['role'] == role_to_use:
                        gemini_history[-1]['parts'].append({"text": content})
                    else:
                        gemini_history.append(
                            {"role": role_to_use, "parts": [{"text": content}]})
            elif isinstance(content, list):
                # Multi-modal content: collect text and inline image parts.
                parts = []
                for item in content:
                    if item.get('type') == 'text':
                        parts.append({"text": item.get('text')})
                    elif item.get('type') == 'image_url':
                        image_data = item.get('image_url', {}).get('url', '')
                        if image_data.startswith('data:image/'):
                            try:
                                # "data:<mime>;base64,<payload>" → (mime, payload)
                                mime_type, base64_data = image_data.split(';')[0].split(':')[1], image_data.split(',')[1]
                                parts.append({
                                    "inline_data": {
                                        "mime_type": mime_type,
                                        "data": base64_data
                                    }
                                })
                            except (IndexError, ValueError):
                                errors.append(
                                    f"Invalid data URI for image: {image_data}")
                        else:
                            errors.append(
                                f"Invalid image URL format for item: {item}")

                if parts:
                    if role in ['user', 'system']:
                        role_to_use = 'user'
                    elif role == 'assistant':
                        role_to_use = 'model'
                    else:
                        errors.append(f"Invalid role: {role}")
                        continue
                    if gemini_history and gemini_history[-1]['role'] == role_to_use:
                        gemini_history[-1]['parts'].extend(parts)
                    else:
                        gemini_history.append(
                            {"role": role_to_use, "parts": parts})
        if errors:
            return errors
        else:
            return gemini_history, {"parts": [{"text": system_instruction_text}]}

    @staticmethod
    async def list_available_models(api_key) -> list:
        """Fetch model names from the v1beta models endpoint, plus EXTRA_MODELS.

        Returned names keep the "models/" prefix as served by the API.
        Raises httpx.HTTPStatusError on non-2xx status.
        """
        url = "https://generativelanguage.googleapis.com/v1beta/models?key={}".format(
            api_key)
        async with httpx.AsyncClient() as client:
            response = await client.get(url)
            response.raise_for_status()
            data = response.json()
            models = [model["name"] for model in data.get("models", [])]
            models.extend(GeminiClient.EXTRA_MODELS)
            return models
app/main.py ADDED
@@ -0,0 +1,378 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, HTTPException, Request, Depends, status
2
+ from fastapi.responses import JSONResponse, StreamingResponse, HTMLResponse
3
+ from .models import ChatCompletionRequest, ChatCompletionResponse, ErrorResponse, ModelList
4
+ from .gemini import GeminiClient, ResponseWrapper
5
+ from .utils import handle_gemini_error, protect_from_abuse, APIKeyManager, test_api_key, format_log_message
6
+ import os
7
+ import json
8
+ import asyncio
9
+ from typing import Literal
10
+ import random
11
+ import requests
12
+ from datetime import datetime, timedelta
13
+ from apscheduler.schedulers.background import BackgroundScheduler
14
+ import sys
15
+ import logging
16
+
17
# Silence uvicorn's own loggers; this app emits pre-formatted lines itself.
logging.getLogger("uvicorn").disabled = True
logging.getLogger("uvicorn.access").disabled = True

# Configure the shared application logger (handler is attached in app.utils).
logger = logging.getLogger("my_logger")
logger.setLevel(logging.DEBUG)
23
+
24
def translate_error(message: str) -> str:
    """Map well-known English Gemini error phrases to a Chinese summary.

    Substring matching is case-insensitive; the first phrase found wins.
    Unknown messages are returned unchanged.
    """
    lowered = message.lower()
    translations = (
        ("quota exceeded", "API 密钥配额已用尽"),
        ("invalid argument", "无效参数"),
        ("internal server error", "服务器内部错误"),
        ("service unavailable", "服务不可用"),
    )
    for needle, translated in translations:
        if needle in lowered:
            return translated
    return message
34
+
35
+
36
def handle_exception(exc_type, exc_value, exc_traceback):
    """Replacement sys.excepthook: log uncaught exceptions via the app logger.

    KeyboardInterrupt is delegated to the interpreter's default hook so
    Ctrl-C still prints the usual traceback and exits.
    """
    if issubclass(exc_type, KeyboardInterrupt):
        # BUGFIX: the original called sys.excepthook(...) here, but the
        # assignment below makes sys.excepthook THIS function, causing
        # infinite recursion on Ctrl-C. sys.__excepthook__ is the
        # interpreter's untouched default hook.
        sys.__excepthook__(exc_type, exc_value, exc_traceback)
        return
    error_message = translate_error(str(exc_value))
    log_msg = format_log_message('ERROR', f"未捕获的异常: {error_message}", extra={'status_code': 500, 'error_message': error_message})
    logger.error(log_msg)


sys.excepthook = handle_exception
46
+
47
app = FastAPI()

# Bearer token required by verify_password; default "123" — override in production.
PASSWORD = os.environ.get("PASSWORD", "123")
# Rate limits enforced by protect_from_abuse (per minute and per IP per day).
MAX_REQUESTS_PER_MINUTE = int(os.environ.get("MAX_REQUESTS_PER_MINUTE", "30"))
MAX_REQUESTS_PER_DAY_PER_IP = int(
    os.environ.get("MAX_REQUESTS_PER_DAY_PER_IP", "600"))
# MAX_RETRIES = int(os.environ.get('MaxRetries', '3').strip() or '3')
# NOTE(review): RETRY_DELAY / MAX_RETRY_DELAY are not referenced in the
# visible code — confirm they are used elsewhere or remove them.
RETRY_DELAY = 1
MAX_RETRY_DELAY = 16
# Default safety settings: disable blocking for every harm category.
safety_settings = [
    {
        "category": "HARM_CATEGORY_HARASSMENT",
        "threshold": "BLOCK_NONE"
    },
    {
        "category": "HARM_CATEGORY_HATE_SPEECH",
        "threshold": "BLOCK_NONE"
    },
    {
        "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
        "threshold": "BLOCK_NONE"
    },
    {
        "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
        "threshold": "BLOCK_NONE"
    },
    {
        "category": 'HARM_CATEGORY_CIVIC_INTEGRITY',
        "threshold": 'BLOCK_NONE'
    }
]
# Variant used for 'gemini-2.0-flash-exp' models, which accept "OFF"
# instead of "BLOCK_NONE" (see process_request).
safety_settings_g2 = [
    {
        "category": "HARM_CATEGORY_HARASSMENT",
        "threshold": "OFF"
    },
    {
        "category": "HARM_CATEGORY_HATE_SPEECH",
        "threshold": "OFF"
    },
    {
        "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
        "threshold": "OFF"
    },
    {
        "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
        "threshold": "OFF"
    },
    {
        "category": 'HARM_CATEGORY_CIVIC_INTEGRITY',
        "threshold": 'OFF'
    }
]

key_manager = APIKeyManager()  # instantiate APIKeyManager; the key stack is built in __init__
current_api_key = key_manager.get_available_key()
103
+
104
+
105
def switch_api_key():
    """Rotate the module-level current_api_key to the next untried key.

    Leaves current_api_key unchanged and logs an error when every key has
    already been tried for this request.
    """
    global current_api_key
    replacement = key_manager.get_available_key()  # stack handling lives in get_available_key
    if not replacement:
        log_msg = format_log_message('ERROR', "API key 替换失败,所有API key都已尝试,请重新配置或稍后重试", extra={'key': 'N/A', 'request_type': 'switch_key', 'status_code': 'N/A'})
        logger.error(log_msg)
        return
    current_api_key = replacement
    log_msg = format_log_message('INFO', f"API key 替换为 → {current_api_key[:8]}...", extra={'key': current_api_key[:8], 'request_type': 'switch_key'})
    logger.info(log_msg)
115
+
116
+
117
async def check_keys():
    """Probe every configured API key against Gemini and return the valid ones.

    Each key's result is logged individually; an error is logged when no
    key passes.
    """
    usable = []
    for candidate in key_manager.api_keys:
        is_valid = await test_api_key(candidate)
        state = "有效" if is_valid else "无效"
        logger.info(format_log_message('INFO', f"API Key {candidate[:10]}... {state}."))
        if is_valid:
            usable.append(candidate)
    if not usable:
        logger.error(format_log_message('ERROR', "没有可用的 API 密钥!", extra={'key': 'N/A', 'request_type': 'startup', 'status_code': 'N/A'}))
    return usable
130
+
131
+
132
@app.on_event("startup")
async def startup_event():
    """Validate keys, rebuild the key stack and load the model list at boot.

    Order matters: keys are validated first, the manager's pool is replaced
    with only the valid ones, then the model list is fetched with the first
    valid key.
    """
    log_msg = format_log_message('INFO', "Starting Gemini API proxy...")
    logger.info(log_msg)
    available_keys = await check_keys()
    if available_keys:
        key_manager.api_keys = available_keys
        key_manager._reset_key_stack()  # rebuild the randomized stack at startup too
        key_manager.show_all_keys()
        log_msg = format_log_message('INFO', f"可用 API 密钥数量:{len(key_manager.api_keys)}")
        logger.info(log_msg)
        # MAX_RETRIES = len(key_manager.api_keys)
        log_msg = format_log_message('INFO', f"最大重试次数设置为:{len(key_manager.api_keys)}")  # log the effective retry budget
        logger.info(log_msg)
        if key_manager.api_keys:
            all_models = await GeminiClient.list_available_models(key_manager.api_keys[0])
            # Strip the "models/" prefix the API puts on every name.
            GeminiClient.AVAILABLE_MODELS = [model.replace(
                "models/", "") for model in all_models]
            log_msg = format_log_message('INFO', "Available models loaded.")
            logger.info(log_msg)
152
+
153
@app.get("/v1/models", response_model=ModelList)
def list_models():
    """Return the cached Gemini model list in OpenAI /v1/models format."""
    logger.info(format_log_message('INFO', "Received request to list models", extra={'request_type': 'list_models', 'status_code': 200}))
    entries = [
        {"id": name, "object": "model", "created": 1678888888, "owned_by": "organization-owner"}
        for name in GeminiClient.AVAILABLE_MODELS
    ]
    return ModelList(data=entries)
158
+
159
+
160
async def verify_password(request: Request):
    """FastAPI dependency enforcing the Bearer-token password.

    A falsy PASSWORD disables authentication entirely.
    Raises HTTPException(401) on a missing, malformed or wrong token.
    """
    if not PASSWORD:
        return
    header_value = request.headers.get("Authorization")
    if not header_value or not header_value.startswith("Bearer "):
        raise HTTPException(
            status_code=401, detail="Unauthorized: Missing or invalid token")
    supplied_token = header_value.split(" ")[1]
    if supplied_token != PASSWORD:
        raise HTTPException(
            status_code=401, detail="Unauthorized: Invalid token")
170
+
171
+
172
async def process_request(chat_request: ChatCompletionRequest, http_request: Request, request_type: Literal['stream', 'non-stream']):
    """Core request pipeline: rate-limit, validate the model, then try each
    API key in turn until one yields a usable response.

    Non-stream requests race the Gemini call against a client-disconnect
    watcher and retry on errors or empty responses with the next key.
    NOTE(review): the stream branch returns the StreamingResponse
    immediately, so mid-stream errors are reported inside the SSE body
    rather than triggering a key retry — confirm this is intended.
    """
    global current_api_key
    protect_from_abuse(
        http_request, MAX_REQUESTS_PER_MINUTE, MAX_REQUESTS_PER_DAY_PER_IP)
    if chat_request.model not in GeminiClient.AVAILABLE_MODELS:
        error_msg = "无效的模型"
        extra_log = {'request_type': request_type, 'model': chat_request.model, 'status_code': 400, 'error_message': error_msg}
        log_msg = format_log_message('ERROR', error_msg, extra=extra_log)
        logger.error(log_msg)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST, detail=error_msg)

    key_manager.reset_tried_keys_for_request()  # reset the tried-keys set at the start of every request

    contents, system_instruction = GeminiClient.convert_messages(
        GeminiClient, chat_request.messages)

    retry_attempts = len(key_manager.api_keys) if key_manager.api_keys else 1  # retry budget = number of keys; always try at least once
    for attempt in range(1, retry_attempts + 1):
        if attempt == 1:
            # First attempt fetches a key here; later attempts rely on
            # switch_api_key() (called in the except handler) to rotate it.
            current_api_key = key_manager.get_available_key()

        if current_api_key is None:  # no key could be obtained
            log_msg_no_key = format_log_message('WARNING', "没有可用的 API 密钥,跳过本次尝试", extra={'request_type': request_type, 'model': chat_request.model, 'status_code': 'N/A'})
            logger.warning(log_msg_no_key)
            break  # no usable key: stop retrying

        extra_log = {'key': current_api_key[:8], 'request_type': request_type, 'model': chat_request.model, 'status_code': 'N/A', 'error_message': ''}
        log_msg = format_log_message('INFO', f"第 {attempt}/{retry_attempts} 次尝试 ... 使用密钥: {current_api_key[:8]}...", extra=extra_log)
        logger.info(log_msg)

        gemini_client = GeminiClient(current_api_key)
        try:
            if chat_request.stream:
                # Streaming path: wrap Gemini SSE chunks in OpenAI-style
                # "chat.completion.chunk" frames.
                async def stream_generator():
                    try:
                        async for chunk in gemini_client.stream_chat(chat_request, contents, safety_settings_g2 if 'gemini-2.0-flash-exp' in chat_request.model else safety_settings, system_instruction):
                            formatted_chunk = {"id": "chatcmpl-someid", "object": "chat.completion.chunk", "created": 1234567,
                                               "model": chat_request.model, "choices": [{"delta": {"role": "assistant", "content": chunk}, "index": 0, "finish_reason": None}]}
                            yield f"data: {json.dumps(formatted_chunk)}\n\n"
                        yield "data: [DONE]\n\n"

                    except asyncio.CancelledError:
                        extra_log_cancel = {'key': current_api_key[:8], 'request_type': request_type, 'model': chat_request.model, 'error_message': '客户端已断开连接'}
                        log_msg = format_log_message('INFO', "客户端连接已中断", extra=extra_log_cancel)
                        logger.info(log_msg)
                    except Exception as e:
                        # Errors after headers are sent can only be reported
                        # inside the SSE stream itself.
                        error_detail = handle_gemini_error(
                            e, current_api_key, key_manager)
                        yield f"data: {json.dumps({'error': {'message': error_detail, 'type': 'gemini_error'}})}\n\n"
                return StreamingResponse(stream_generator(), media_type="text/event-stream")
            else:
                # Non-stream path: run the blocking requests call in a thread.
                async def run_gemini_completion():
                    try:
                        response_content = await asyncio.to_thread(gemini_client.complete_chat, chat_request, contents, safety_settings_g2 if 'gemini-2.0-flash-exp' in chat_request.model else safety_settings, system_instruction)
                        return response_content
                    except asyncio.CancelledError:
                        extra_log_gemini_cancel = {'key': current_api_key[:8], 'request_type': request_type, 'model': chat_request.model, 'error_message': '客户端断开导致API调用取消'}
                        log_msg = format_log_message('INFO', "API调用因客户端断开而取消", extra=extra_log_gemini_cancel)
                        logger.info(log_msg)
                        raise

                # Polls every 0.5s so a vanished client cancels the API call.
                async def check_client_disconnect():
                    while True:
                        if await http_request.is_disconnected():
                            extra_log_client_disconnect = {'key': current_api_key[:8], 'request_type': request_type, 'model': chat_request.model, 'error_message': '检测到客户端断开连接'}
                            log_msg = format_log_message('INFO', "客户端连接已中断,正在取消API请求", extra=extra_log_client_disconnect)
                            logger.info(log_msg)
                            return True
                        await asyncio.sleep(0.5)

                gemini_task = asyncio.create_task(run_gemini_completion())
                disconnect_task = asyncio.create_task(check_client_disconnect())

                try:
                    done, pending = await asyncio.wait(
                        [gemini_task, disconnect_task],
                        return_when=asyncio.FIRST_COMPLETED
                    )

                    if disconnect_task in done:
                        gemini_task.cancel()
                        try:
                            await gemini_task
                        except asyncio.CancelledError:
                            extra_log_gemini_task_cancel = {'key': current_api_key[:8], 'request_type': request_type, 'model': chat_request.model, 'error_message': 'API任务已终止'}
                            log_msg = format_log_message('INFO', "API任务已成功取消", extra=extra_log_gemini_task_cancel)
                            logger.info(log_msg)
                        # Raise directly to break out of the retry loop.
                        raise HTTPException(status_code=status.HTTP_408_REQUEST_TIMEOUT, detail="客户端连接已中断")

                    if gemini_task in done:
                        disconnect_task.cancel()
                        try:
                            await disconnect_task
                        except asyncio.CancelledError:
                            pass
                        response_content = gemini_task.result()
                        if response_content.text == "":
                            extra_log_empty_response = {'key': current_api_key[:8], 'request_type': request_type, 'model': chat_request.model, 'status_code': 204}
                            log_msg = format_log_message('INFO', "Gemini API 返回空响应", extra=extra_log_empty_response)
                            logger.info(log_msg)
                            # Empty response: continue with the next key.
                            continue
                        response = ChatCompletionResponse(id="chatcmpl-someid", object="chat.completion", created=1234567890, model=chat_request.model,
                                                          choices=[{"index": 0, "message": {"role": "assistant", "content": response_content.text}, "finish_reason": "stop"}])
                        extra_log_success = {'key': current_api_key[:8], 'request_type': request_type, 'model': chat_request.model, 'status_code': 200}
                        log_msg = format_log_message('INFO', "请求处理成功", extra=extra_log_success)
                        logger.info(log_msg)
                        return response

                except asyncio.CancelledError:
                    extra_log_request_cancel = {'key': current_api_key[:8], 'request_type': request_type, 'model': chat_request.model, 'error_message': "请求被取消"}
                    log_msg = format_log_message('INFO', "请求取消", extra=extra_log_request_cancel)
                    logger.info(log_msg)
                    raise

        except HTTPException as e:
            if e.status_code == status.HTTP_408_REQUEST_TIMEOUT:
                extra_log = {'key': current_api_key[:8], 'request_type': request_type, 'model': chat_request.model,
                             'status_code': 408, 'error_message': '客户端连接中断'}
                log_msg = format_log_message('ERROR', "客户端连接中断,终止后续重试", extra=extra_log)
                logger.error(log_msg)
                raise
            else:
                raise
        except Exception as e:
            handle_gemini_error(e, current_api_key, key_manager)
            if attempt < retry_attempts:
                switch_api_key()
                continue

    msg = "所有API密钥均失败,请稍后重试"
    extra_log_all_fail = {'key': "ALL", 'request_type': request_type, 'model': chat_request.model, 'status_code': 500, 'error_message': msg}
    log_msg = format_log_message('ERROR', msg, extra=extra_log_all_fail)
    logger.error(log_msg)
    raise HTTPException(
        status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=msg)
310
+
311
+
312
@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
async def chat_completions(request: ChatCompletionRequest, http_request: Request, _: None = Depends(verify_password)):
    """OpenAI-compatible chat endpoint; password-gated, delegates to process_request."""
    mode = "stream" if request.stream else "non-stream"
    return await process_request(request, http_request, mode)
315
+
316
+
317
@app.exception_handler(Exception)
async def global_exception_handler(request: Request, exc: Exception):
    """Last-resort handler: log the translated error, return a 500 JSON body."""
    error_message = translate_error(str(exc))
    log_msg = format_log_message(
        'ERROR',
        f"Unhandled exception: {error_message}",
        extra={'status_code': 500, 'error_message': error_message},
    )
    logger.error(log_msg)
    body = ErrorResponse(message=str(exc), type="internal_error").dict()
    return JSONResponse(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, content=body)
324
+
325
+
326
@app.get("/", response_class=HTMLResponse)
async def root():
    """Status landing page: key/model counts and configured rate limits."""
    html_content = f"""
    <!DOCTYPE html>
    <html>
    <head>
        <title>Gemini API 代理服务</title>
        <style>
            body {{
                font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, "Helvetica Neue", Arial, sans-serif;
                max-width: 800px;
                margin: 0 auto;
                padding: 20px;
                line-height: 1.6;
            }}
            h1 {{
                color: #333;
                text-align: center;
                margin-bottom: 30px;
            }}
            .info-box {{
                background-color: #f8f9fa;
                border: 1px solid #dee2e6;
                border-radius: 4px;
                padding: 20px;
                margin-bottom: 20px;
            }}
            .status {{
                color: #28a745;
                font-weight: bold;
            }}
        </style>
    </head>
    <body>
        <h1>🤖 Gemini API 代理服务</h1>

        <div class="info-box">
            <h2>🟢 运行状态</h2>
            <p class="status">服务运行中</p>
            <p>可用API密钥数量: {len(key_manager.api_keys)}</p>
            <p>可用模型数量: {len(GeminiClient.AVAILABLE_MODELS)}</p>
        </div>

        <div class="info-box">
            <h2>⚙️ 环境配置</h2>
            <p>每分钟请求限制: {MAX_REQUESTS_PER_MINUTE}</p>
            <p>每IP每日请求限制: {MAX_REQUESTS_PER_DAY_PER_IP}</p>
            <p>最大重试次数: {len(key_manager.api_keys)}</p>
        </div>
    </body>
    </html>
    """
    return html_content
app/models.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Dict, Optional, Union, Literal
2
+ from pydantic import BaseModel, Field
3
+
4
class Message(BaseModel):
    """One OpenAI-style chat message."""
    # "user", "assistant" or "system" (mapped in GeminiClient.convert_messages).
    role: str
    # NOTE(review): convert_messages also handles list-valued content, so this
    # str annotation may be narrower than actual use — confirm.
    content: str
7
+
8
class ChatCompletionRequest(BaseModel):
    """OpenAI-compatible /v1/chat/completions request body."""
    # Must be one of GeminiClient.AVAILABLE_MODELS (validated in process_request).
    model: str
    messages: List[Message]
    # Forwarded to Gemini generationConfig.temperature.
    temperature: float = 0.7
    # NOTE(review): accepted for OpenAI compatibility but not forwarded
    # to Gemini in the visible code: top_p, n, stop, presence_penalty,
    # frequency_penalty.
    top_p: Optional[float] = 1.0
    n: int = 1
    # Selects SSE streaming vs blocking completion.
    stream: bool = False
    stop: Optional[Union[str, List[str]]] = None
    # Forwarded to Gemini generationConfig.maxOutputTokens.
    max_tokens: Optional[int] = None
    presence_penalty: Optional[float] = 0.0
    frequency_penalty: Optional[float] = 0.0
19
+
20
class Choice(BaseModel):
    """One completion choice within a ChatCompletionResponse."""
    index: int
    message: Message
    # e.g. "stop"; None when not reported.
    finish_reason: Optional[str] = None
24
+
25
class Usage(BaseModel):
    """Token accounting block of an OpenAI-style response (defaults to zero)."""
    prompt_tokens: int = 0
    completion_tokens: int = 0
    total_tokens: int = 0
29
+
30
class ChatCompletionResponse(BaseModel):
    """OpenAI-compatible /v1/chat/completions response body."""
    id: str
    object: Literal["chat.completion"]
    # Unix timestamp of creation.
    created: int
    model: str
    choices: List[Choice]
    # Defaults to an all-zero Usage when the caller does not supply one.
    usage: Usage = Field(default_factory=Usage)
37
+
38
class ErrorResponse(BaseModel):
    """OpenAI-style error payload returned by the global exception handler."""
    message: str
    # Error category, e.g. "internal_error".
    type: str
    param: Optional[str] = None
    code: Optional[str] = None
43
+
44
class ModelList(BaseModel):
    """OpenAI-style /v1/models listing: {"object": "list", "data": [...]}."""
    object: str = "list"
    # Each entry: {"id", "object", "created", "owned_by"} (see list_models).
    data: List[Dict]
app/utils.py ADDED
@@ -0,0 +1,240 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ from fastapi import HTTPException, Request
3
+ import time
4
+ import re
5
+ from datetime import datetime, timedelta
6
+ from apscheduler.schedulers.background import BackgroundScheduler
7
+ import os
8
+ import requests
9
+ import httpx
10
+ from threading import Lock
11
+ import logging
12
+ import sys
13
+
14
# Verbose log format is enabled when the DEBUG env var equals "true".
DEBUG = os.environ.get("DEBUG", "false").lower() == "true"
LOG_FORMAT_DEBUG = '%(asctime)s - %(levelname)s - [%(key)s]-%(request_type)s-[%(model)s]-%(status_code)s: %(message)s - %(error_message)s'
LOG_FORMAT_NORMAL = '[%(key)s]-%(request_type)s-[%(model)s]-%(status_code)s: %(message)s'

# Configure the shared application logger ("my_logger" is also used by app.main
# and app.gemini); lines are pre-formatted by format_log_message below.
logger = logging.getLogger("my_logger")
logger.setLevel(logging.DEBUG)

handler = logging.StreamHandler()
# formatter = logging.Formatter('%(message)s')
# handler.setFormatter(formatter)
logger.addHandler(handler)
26
+
27
def format_log_message(level, message, extra=None):
    """Render one log line using the debug or normal format template.

    ``extra`` may provide key / request_type / model / status_code /
    error_message; anything missing falls back to 'N/A' (error_message
    falls back to '').
    """
    details = extra or {}
    fields = {
        'asctime': datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        'levelname': level,
        'error_message': details.get('error_message', ''),
        'message': message,
    }
    for optional_key in ('key', 'request_type', 'model', 'status_code'):
        fields[optional_key] = details.get(optional_key, 'N/A')
    template = LOG_FORMAT_DEBUG if DEBUG else LOG_FORMAT_NORMAL
    return template % fields
41
+
42
+
43
class APIKeyManager:
    """Manages the pool of Gemini API keys.

    Keys are parsed from the GEMINI_API_KEYS env var, handed out from a
    randomly shuffled stack, and keys already tried during the current
    request are tracked so each key is attempted at most once per request.
    """

    def __init__(self):
        # Extract every token matching the Gemini API key format.
        self.api_keys = re.findall(
            r"AIzaSy[a-zA-Z0-9_-]{33}", os.environ.get('GEMINI_API_KEYS', ""))
        self.key_stack = []  # keys waiting to be handed out (top = end of list)
        self._reset_key_stack()  # build the initial randomized stack
        self.scheduler = BackgroundScheduler()
        self.scheduler.start()
        self.tried_keys_for_request = set()  # keys already tried for the current request

    def _reset_key_stack(self):
        """Rebuild the key stack as a fresh random permutation of api_keys."""
        shuffled_keys = self.api_keys[:]  # copy so the master list is untouched
        random.shuffle(shuffled_keys)
        self.key_stack = shuffled_keys

    def _pop_untried_key(self):
        """Pop keys off the stack until one untried this request is found.

        Returns the key (recorded as tried) or None when the stack runs dry.
        """
        while self.key_stack:
            key = self.key_stack.pop()
            if key not in self.tried_keys_for_request:
                self.tried_keys_for_request.add(key)
                return key
        return None

    def get_available_key(self):
        """Return an untried key, refilling the stack once when it runs dry.

        The original duplicated the pop loop verbatim before and after the
        refill; both passes now share _pop_untried_key. Returns None when no
        keys are configured or every key was already tried this request.
        """
        key = self._pop_untried_key()
        if key is not None:
            return key

        if not self.api_keys:
            log_msg = format_log_message('ERROR', "没有配置任何 API 密钥!")
            logger.error(log_msg)
            return None

        self._reset_key_stack()  # regenerate the stack and try one more pass
        return self._pop_untried_key()

    def show_all_keys(self):
        """Log the (partially masked) configured keys."""
        log_msg = format_log_message('INFO', f"当前可用API key个数: {len(self.api_keys)} ")
        logger.info(log_msg)
        for i, api_key in enumerate(self.api_keys):
            log_msg = format_log_message('INFO', f"API Key{i}: {api_key[:8]}...{api_key[-3:]}")
            logger.info(log_msg)

    def reset_tried_keys_for_request(self):
        """Reset the tried-keys set at the start of a new request."""
        self.tried_keys_for_request = set()
106
+
107
+
108
def handle_gemini_error(error, current_api_key, key_manager) -> str:
    """Map an exception raised during a Gemini API call to a short error string.

    Logs a structured message for every recognised failure mode and returns a
    human-readable summary for the caller.  ``key_manager`` is accepted for
    interface compatibility; key blacklisting is currently disabled.
    """
    if isinstance(error, requests.exceptions.HTTPError):
        status_code = error.response.status_code
        if status_code == 400:
            try:
                error_data = error.response.json()
                if 'error' in error_data:
                    # BUG FIX: Google API error payloads carry a numeric
                    # ``code`` (e.g. 400) and an upper-case ``status`` string
                    # (e.g. "INVALID_ARGUMENT").  The previous check compared
                    # ``code`` against the lowercase string
                    # "invalid_argument", which can never match, leaving the
                    # invalid-key branch dead.  Check ``status`` instead; the
                    # old comparison is kept as a defensive fallback.
                    if (error_data['error'].get('status') == "INVALID_ARGUMENT"
                            or error_data['error'].get('code') == "invalid_argument"):
                        error_message = "无效的 API 密钥"
                        extra_log_invalid_key = {'key': current_api_key[:8], 'status_code': status_code, 'error_message': error_message}
                        log_msg = format_log_message('ERROR', f"{current_api_key[:8]} ... {current_api_key[-3:]} → 无效,可能已过期或被删除", extra=extra_log_invalid_key)
                        logger.error(log_msg)
                        return error_message
                    error_message = error_data['error'].get(
                        'message', 'Bad Request')
                    extra_log_400 = {'key': current_api_key[:8], 'status_code': status_code, 'error_message': error_message}
                    log_msg = format_log_message('WARNING', f"400 错误请求: {error_message}", extra=extra_log_400)
                    logger.warning(log_msg)
                    return f"400 错误请求: {error_message}"
                # JSON body without an "error" object: honour the declared
                # ``-> str`` contract instead of implicitly returning None.
                return "400 错误请求"
            except ValueError:
                # Response body was not JSON at all.
                error_message = "400 错误请求:响应不是有效的JSON格式"
                extra_log_400_json = {'key': current_api_key[:8], 'status_code': status_code, 'error_message': error_message}
                log_msg = format_log_message('WARNING', error_message, extra=extra_log_400_json)
                logger.warning(log_msg)
                return error_message

        elif status_code == 429:
            # Quota exhausted (or similar throttling) for this key.
            error_message = "API 密钥配额已用尽或其他原因"
            extra_log_429 = {'key': current_api_key[:8], 'status_code': status_code, 'error_message': error_message}
            log_msg = format_log_message('WARNING', f"{current_api_key[:8]} ... {current_api_key[-3:]} → 429 官方资源耗尽或其他原因", extra=extra_log_429)
            logger.warning(log_msg)
            return error_message

        elif status_code == 403:
            error_message = "权限被拒绝"
            extra_log_403 = {'key': current_api_key[:8], 'status_code': status_code, 'error_message': error_message}
            log_msg = format_log_message('ERROR', f"{current_api_key[:8]} ... {current_api_key[-3:]} → 403 权限被拒绝", extra=extra_log_403)
            logger.error(log_msg)
            return error_message

        elif status_code == 500:
            error_message = "服务器内部错误"
            extra_log_500 = {'key': current_api_key[:8], 'status_code': status_code, 'error_message': error_message}
            log_msg = format_log_message('WARNING', f"{current_api_key[:8]} ... {current_api_key[-3:]} → 500 服务器内部错误", extra=extra_log_500)
            logger.warning(log_msg)
            return "Gemini API 内部错误"

        elif status_code == 503:
            error_message = "服务不可用"
            extra_log_503 = {'key': current_api_key[:8], 'status_code': status_code, 'error_message': error_message}
            log_msg = format_log_message('WARNING', f"{current_api_key[:8]} ... {current_api_key[-3:]} → 503 服务不可用", extra=extra_log_503)
            logger.warning(log_msg)
            return "Gemini API 服务不可用"

        else:
            error_message = f"未知错误: {status_code}"
            extra_log_other = {'key': current_api_key[:8], 'status_code': status_code, 'error_message': error_message}
            log_msg = format_log_message('WARNING', f"{current_api_key[:8]} ... {current_api_key[-3:]} → {status_code} 未知错误", extra=extra_log_other)
            logger.warning(log_msg)
            return f"未知错误/模型不可用: {status_code}"

    elif isinstance(error, requests.exceptions.ConnectionError):
        error_message = "连接错误"
        log_msg = format_log_message('WARNING', error_message, extra={'error_message': error_message})
        logger.warning(log_msg)
        return error_message

    elif isinstance(error, requests.exceptions.Timeout):
        error_message = "请求超时"
        log_msg = format_log_message('WARNING', error_message, extra={'error_message': error_message})
        logger.warning(log_msg)
        return error_message

    else:
        # Anything unexpected: log at ERROR with the repr-style message.
        error_message = f"发生未知错误: {error}"
        log_msg = format_log_message('ERROR', error_message, extra={'error_message': error_message})
        logger.error(log_msg)
        return error_message
192
+
193
+
194
async def test_api_key(api_key: str) -> bool:
    """Return True when *api_key* can list models via the Gemini API."""
    endpoint = f"https://generativelanguage.googleapis.com/v1beta/models?key={api_key}"
    try:
        async with httpx.AsyncClient() as client:
            response = await client.get(endpoint)
            response.raise_for_status()
    except Exception:
        # Any failure — network error or non-2xx status — counts as invalid.
        return False
    return True
206
+
207
+
208
# In-memory rate-limit state: key -> (count, window_start, expires_at).
# Minute windows are keyed by "<path>:<minute-number>", day windows by
# "<client-ip>:<day-number>".
rate_limit_data = {}
rate_limit_lock = Lock()


def protect_from_abuse(request: "Request", max_requests_per_minute: int = 30, max_requests_per_day_per_ip: int = 600):
    """Simple in-memory rate limiter.

    Enforces a per-minute limit per endpoint path and a per-day limit per
    client IP.  Raises HTTPException(429) with a structured detail dict when
    either limit is exceeded.

    BUG FIX: stale minute/day entries were previously never removed from
    ``rate_limit_data`` (their keys embed the window number, so they are
    never hit again), growing the dict without bound.  Each call now prunes
    expired windows; entries carry an explicit expiry timestamp.
    """
    now = int(time.time())
    minute = now // 60
    day = now // (60 * 60 * 24)

    minute_key = f"{request.url.path}:{minute}"
    day_key = f"{request.client.host}:{day}"

    with rate_limit_lock:
        # Prune expired windows so memory stays bounded.
        for stale_key in [k for k, v in rate_limit_data.items() if v[2] <= now]:
            del rate_limit_data[stale_key]

        minute_count, minute_start, _ = rate_limit_data.get(
            minute_key, (0, now, 0))
        if now - minute_start >= 60:
            minute_count = 0
            minute_start = now
        minute_count += 1
        # Keep the minute entry for two minutes, comfortably past its window.
        rate_limit_data[minute_key] = (minute_count, minute_start, minute_start + 120)

        day_count, day_start, _ = rate_limit_data.get(day_key, (0, now, 0))
        if now - day_start >= 86400:
            day_count = 0
            day_start = now
        day_count += 1
        # Keep the day entry for two days, comfortably past its window.
        rate_limit_data[day_key] = (day_count, day_start, day_start + 2 * 86400)

        if minute_count > max_requests_per_minute:
            raise HTTPException(status_code=429, detail={
                "message": "Too many requests per minute", "limit": max_requests_per_minute})
        if day_count > max_requests_per_day_per_ip:
            raise HTTPException(status_code=429, detail={"message": "Too many requests per day from this IP", "limit": max_requests_per_day_per_ip})
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ fastapi
2
+ uvicorn
3
+ httpx
4
+ python-dotenv
5
+ requests
6
+ apscheduler