Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
sachin
commited on
Commit
·
844386f
1
Parent(s):
643e32f
bearer-auth
Browse files- docs/menv.md +2 -2
- requirements.txt +2 -1
- src/server/main.py +39 -71
- src/server/utils/auth.py +84 -11
docs/menv.md
CHANGED
@@ -4,6 +4,6 @@ export SPEECH_RATE_LIMIT=5/minute
|
|
4 |
export CHAT_RATE_LIMIT=100/minute
|
5 |
export EXTERNAL_TTS_URL=https://slabstech-dhwani-internal-api-server.hf.space/v1/audio/speech
|
6 |
export EXTERNAL_ASR_URL=https://gaganyatri-asr-indic-server-cpu.hf.space
|
7 |
-
export EXTERNAL_TEXT_GEN_URL=https://
|
8 |
-
export EXTERNAL_AUDIO_PROC_URL=https://
|
9 |
export API_KEY_SECRET=your_secret_key
|
|
|
4 |
export CHAT_RATE_LIMIT=100/minute
|
5 |
export EXTERNAL_TTS_URL=https://slabstech-dhwani-internal-api-server.hf.space/v1/audio/speech
|
6 |
export EXTERNAL_ASR_URL=https://gaganyatri-asr-indic-server-cpu.hf.space
|
7 |
+
export EXTERNAL_TEXT_GEN_URL=https://slabstech-dhwani-internal-api-server.hf.space
|
8 |
+
export EXTERNAL_AUDIO_PROC_URL=https://slabstech-dhwani-internal-api-server.hf.space
|
9 |
export API_KEY_SECRET=your_secret_key
|
requirements.txt
CHANGED
@@ -4,4 +4,5 @@ pydantic_settings
|
|
4 |
slowapi
|
5 |
requests
|
6 |
python-multipart
|
7 |
-
pillow
|
|
|
|
4 |
slowapi
|
5 |
requests
|
6 |
python-multipart
|
7 |
+
pillow
|
8 |
+
pyjwt
|
src/server/main.py
CHANGED
@@ -9,40 +9,17 @@ from fastapi import Depends, FastAPI, File, HTTPException, Query, Request, Uploa
|
|
9 |
from fastapi.middleware.cors import CORSMiddleware
|
10 |
from fastapi.responses import JSONResponse, RedirectResponse, StreamingResponse
|
11 |
from pydantic import BaseModel, Field, field_validator
|
12 |
-
from pydantic_settings import BaseSettings
|
13 |
from slowapi import Limiter
|
14 |
from slowapi.util import get_remote_address
|
15 |
import requests
|
16 |
from PIL import Image
|
17 |
|
|
|
|
|
|
|
18 |
# Assuming these are in your project structure
|
19 |
from config.tts_config import SPEED, ResponseFormat, config as tts_config
|
20 |
from config.logging_config import logger
|
21 |
-
#from utils.auth import get_api_key
|
22 |
-
|
23 |
-
# Configuration settings
|
24 |
-
class Settings(BaseSettings):
|
25 |
-
llm_model_name: str = "google/gemma-3-4b-it"
|
26 |
-
max_tokens: int = 512
|
27 |
-
host: str = "0.0.0.0"
|
28 |
-
port: int = 7860
|
29 |
-
chat_rate_limit: str = "100/minute"
|
30 |
-
speech_rate_limit: str = "5/minute"
|
31 |
-
external_tts_url: str = Field(..., env="EXTERNAL_TTS_URL")
|
32 |
-
external_asr_url: str = Field(..., env="EXTERNAL_ASR_URL")
|
33 |
-
external_text_gen_url: str = Field(..., env="EXTERNAL_TEXT_GEN_URL")
|
34 |
-
external_audio_proc_url: str = Field(..., env="EXTERNAL_AUDIO_PROC_URL")
|
35 |
-
api_key_secret: str = Field(..., env="API_KEY_SECRET")
|
36 |
-
|
37 |
-
@field_validator("chat_rate_limit", "speech_rate_limit")
|
38 |
-
def validate_rate_limit(cls, v):
|
39 |
-
if not v.count("/") == 1 or not v.split("/")[0].isdigit():
|
40 |
-
raise ValueError("Rate limit must be in format 'number/period' (e.g., '5/minute')")
|
41 |
-
return v
|
42 |
-
|
43 |
-
class Config:
|
44 |
-
env_file = ".env"
|
45 |
-
env_file_encoding = "utf-8"
|
46 |
|
47 |
settings = Settings()
|
48 |
|
@@ -127,12 +104,17 @@ async def health_check():
|
|
127 |
async def home():
|
128 |
return RedirectResponse(url="/docs")
|
129 |
|
|
|
|
|
|
|
|
|
|
|
130 |
@app.post("/v1/audio/speech")
|
131 |
@limiter.limit(settings.speech_rate_limit)
|
132 |
async def generate_audio(
|
133 |
request: Request,
|
134 |
speech_request: SpeechRequest = Depends(),
|
135 |
-
|
136 |
tts_service: TTSService = Depends(get_tts_service)
|
137 |
):
|
138 |
if not speech_request.input.strip():
|
@@ -141,7 +123,8 @@ async def generate_audio(
|
|
141 |
logger.info("Processing speech request", extra={
|
142 |
"endpoint": "/v1/audio/speech",
|
143 |
"input_length": len(speech_request.input),
|
144 |
-
"client_ip": get_remote_address(request)
|
|
|
145 |
})
|
146 |
|
147 |
payload = {
|
@@ -167,10 +150,9 @@ async def generate_audio(
|
|
167 |
headers=headers
|
168 |
)
|
169 |
|
170 |
-
|
171 |
class ChatRequest(BaseModel):
|
172 |
prompt: str
|
173 |
-
src_lang: str = "kan_Knda"
|
174 |
|
175 |
@field_validator("prompt")
|
176 |
def prompt_must_be_valid(cls, v):
|
@@ -181,22 +163,23 @@ class ChatRequest(BaseModel):
|
|
181 |
class ChatResponse(BaseModel):
|
182 |
response: str
|
183 |
|
184 |
-
|
185 |
@app.post("/v1/chat", response_model=ChatResponse)
|
186 |
@limiter.limit(settings.chat_rate_limit)
|
187 |
-
async def chat(
|
|
|
|
|
|
|
|
|
188 |
if not chat_request.prompt:
|
189 |
raise HTTPException(status_code=400, detail="Prompt cannot be empty")
|
190 |
-
logger.info(f"Received prompt: {chat_request.prompt}, src_lang: {chat_request.src_lang}")
|
191 |
|
192 |
try:
|
193 |
-
|
194 |
-
# Call the external API instead of llm_manager.generate
|
195 |
external_url = "https://slabstech-dhwani-internal-api-server.hf.space/v1/chat"
|
196 |
payload = {
|
197 |
-
"prompt": chat_request.prompt
|
198 |
-
"src_lang": chat_request.src_lang,
|
199 |
-
"tgt_lang"
|
200 |
}
|
201 |
|
202 |
response = requests.post(
|
@@ -208,14 +191,12 @@ async def chat(request: Request, chat_request: ChatRequest):
|
|
208 |
},
|
209 |
timeout=60
|
210 |
)
|
211 |
-
response.raise_for_status()
|
212 |
|
213 |
-
# Extract the response text from the API
|
214 |
response_data = response.json()
|
215 |
-
|
216 |
-
logger.info(f"Generated Chat response from external API: {
|
217 |
-
|
218 |
-
return ChatResponse(response=response)
|
219 |
|
220 |
except requests.Timeout:
|
221 |
logger.error("External chat API request timed out")
|
@@ -232,13 +213,14 @@ async def chat(request: Request, chat_request: ChatRequest):
|
|
232 |
async def process_audio(
|
233 |
file: UploadFile = File(...),
|
234 |
language: str = Query(..., enum=["kannada", "hindi", "tamil"]),
|
235 |
-
|
236 |
request: Request = None,
|
237 |
):
|
238 |
logger.info("Processing audio processing request", extra={
|
239 |
"endpoint": "/v1/process_audio",
|
240 |
"filename": file.filename,
|
241 |
-
"client_ip": get_remote_address(request)
|
|
|
242 |
})
|
243 |
|
244 |
start_time = time()
|
@@ -269,16 +251,9 @@ async def process_audio(
|
|
269 |
async def transcribe_audio(
|
270 |
file: UploadFile = File(...),
|
271 |
language: str = Query(..., enum=["kannada", "hindi", "tamil"]),
|
272 |
-
|
273 |
request: Request = None,
|
274 |
):
|
275 |
-
'''
|
276 |
-
logger.info("Processing transcription request", extra={
|
277 |
-
"endpoint": "/v1/transcribe",
|
278 |
-
"filename": file.filename,
|
279 |
-
"client_ip": get_remote_address(request)
|
280 |
-
})
|
281 |
-
'''
|
282 |
start_time = time()
|
283 |
try:
|
284 |
file_content = await file.read()
|
@@ -294,13 +269,11 @@ async def transcribe_audio(
|
|
294 |
response.raise_for_status()
|
295 |
|
296 |
transcription = response.json().get("text", "")
|
297 |
-
#logger.info(f"Transcription completed in {time() - start_time:.2f} seconds")
|
298 |
return TranscriptionResponse(text=transcription)
|
299 |
|
300 |
except requests.Timeout:
|
301 |
raise HTTPException(status_code=504, detail="Transcription service timeout")
|
302 |
except requests.RequestException as e:
|
303 |
-
#logger.error(f"Transcription request failed: {str(e)}")
|
304 |
raise HTTPException(status_code=500, detail=f"Transcription failed: {str(e)}")
|
305 |
|
306 |
@app.post("/v1/chat_v2", response_model=TranscriptionResponse)
|
@@ -309,7 +282,7 @@ async def chat_v2(
|
|
309 |
request: Request,
|
310 |
prompt: str = Form(...),
|
311 |
image: UploadFile = File(default=None),
|
312 |
-
|
313 |
):
|
314 |
if not prompt:
|
315 |
raise HTTPException(status_code=400, detail="Prompt cannot be empty")
|
@@ -318,18 +291,18 @@ async def chat_v2(
|
|
318 |
"endpoint": "/v1/chat_v2",
|
319 |
"prompt_length": len(prompt),
|
320 |
"has_image": bool(image),
|
321 |
-
"client_ip": get_remote_address(request)
|
|
|
322 |
})
|
323 |
|
324 |
try:
|
325 |
-
# For demonstration, we'll just return the prompt as text
|
326 |
image_data = Image.open(await image.read()) if image else None
|
327 |
response_text = f"Processed: {prompt}" + (" with image" if image_data else "")
|
328 |
return TranscriptionResponse(text=response_text)
|
329 |
except Exception as e:
|
330 |
logger.error(f"Chat_v2 processing failed: {str(e)}", exc_info=True)
|
331 |
raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")
|
332 |
-
|
333 |
class TranslationRequest(BaseModel):
|
334 |
sentences: list[str]
|
335 |
src_lang: str
|
@@ -339,13 +312,14 @@ class TranslationResponse(BaseModel):
|
|
339 |
translations: list[str]
|
340 |
|
341 |
@app.post("/v1/translate", response_model=TranslationResponse)
|
342 |
-
async def translate(
|
343 |
-
|
|
|
|
|
|
|
344 |
|
345 |
-
# External API endpoint
|
346 |
external_url = f"https://slabstech-dhwani-internal-api-server.hf.space/translate?src_lang={request.src_lang}&tgt_lang={request.tgt_lang}"
|
347 |
|
348 |
-
# Prepare the payload matching the external API's expected format
|
349 |
payload = {
|
350 |
"sentences": request.sentences,
|
351 |
"src_lang": request.src_lang,
|
@@ -353,7 +327,6 @@ async def translate(request: TranslationRequest):
|
|
353 |
}
|
354 |
|
355 |
try:
|
356 |
-
# Make the POST request to the external API
|
357 |
response = requests.post(
|
358 |
external_url,
|
359 |
json=payload,
|
@@ -361,13 +334,10 @@ async def translate(request: TranslationRequest):
|
|
361 |
"accept": "application/json",
|
362 |
"Content-Type": "application/json"
|
363 |
},
|
364 |
-
timeout=60
|
365 |
)
|
366 |
-
|
367 |
-
# Raise an exception for bad status codes (4xx, 5xx)
|
368 |
response.raise_for_status()
|
369 |
|
370 |
-
# Extract translations from the response
|
371 |
response_data = response.json()
|
372 |
translations = response_data.get("translations", [])
|
373 |
|
@@ -388,8 +358,6 @@ async def translate(request: TranslationRequest):
|
|
388 |
logger.error(f"Invalid JSON response: {str(e)}")
|
389 |
raise HTTPException(status_code=500, detail="Invalid response format from translation service")
|
390 |
|
391 |
-
|
392 |
-
|
393 |
if __name__ == "__main__":
|
394 |
parser = argparse.ArgumentParser(description="Run the FastAPI server.")
|
395 |
parser.add_argument("--port", type=int, default=settings.port, help="Port to run the server on.")
|
|
|
9 |
from fastapi.middleware.cors import CORSMiddleware
|
10 |
from fastapi.responses import JSONResponse, RedirectResponse, StreamingResponse
|
11 |
from pydantic import BaseModel, Field, field_validator
|
|
|
12 |
from slowapi import Limiter
|
13 |
from slowapi.util import get_remote_address
|
14 |
import requests
|
15 |
from PIL import Image
|
16 |
|
17 |
+
# Import from auth.py
|
18 |
+
from utils.auth import get_current_user, login, TokenResponse, Settings
|
19 |
+
|
20 |
# Assuming these are in your project structure
|
21 |
from config.tts_config import SPEED, ResponseFormat, config as tts_config
|
22 |
from config.logging_config import logger
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
23 |
|
24 |
settings = Settings()
|
25 |
|
|
|
104 |
async def home():
|
105 |
return RedirectResponse(url="/docs")
|
106 |
|
107 |
+
@app.post("/v1/token", response_model=TokenResponse)
|
108 |
+
async def token(user_id: str = Form(...)):
|
109 |
+
# In production, add proper authentication (e.g., password validation)
|
110 |
+
return await login(user_id=user_id)
|
111 |
+
|
112 |
@app.post("/v1/audio/speech")
|
113 |
@limiter.limit(settings.speech_rate_limit)
|
114 |
async def generate_audio(
|
115 |
request: Request,
|
116 |
speech_request: SpeechRequest = Depends(),
|
117 |
+
user_id: str = Depends(get_current_user),
|
118 |
tts_service: TTSService = Depends(get_tts_service)
|
119 |
):
|
120 |
if not speech_request.input.strip():
|
|
|
123 |
logger.info("Processing speech request", extra={
|
124 |
"endpoint": "/v1/audio/speech",
|
125 |
"input_length": len(speech_request.input),
|
126 |
+
"client_ip": get_remote_address(request),
|
127 |
+
"user_id": user_id
|
128 |
})
|
129 |
|
130 |
payload = {
|
|
|
150 |
headers=headers
|
151 |
)
|
152 |
|
|
|
153 |
class ChatRequest(BaseModel):
|
154 |
prompt: str
|
155 |
+
src_lang: str = "kan_Knda"
|
156 |
|
157 |
@field_validator("prompt")
|
158 |
def prompt_must_be_valid(cls, v):
|
|
|
163 |
class ChatResponse(BaseModel):
|
164 |
response: str
|
165 |
|
|
|
166 |
@app.post("/v1/chat", response_model=ChatResponse)
|
167 |
@limiter.limit(settings.chat_rate_limit)
|
168 |
+
async def chat(
|
169 |
+
request: Request,
|
170 |
+
chat_request: ChatRequest,
|
171 |
+
user_id: str = Depends(get_current_user)
|
172 |
+
):
|
173 |
if not chat_request.prompt:
|
174 |
raise HTTPException(status_code=400, detail="Prompt cannot be empty")
|
175 |
+
logger.info(f"Received prompt: {chat_request.prompt}, src_lang: {chat_request.src_lang}, user_id: {user_id}")
|
176 |
|
177 |
try:
|
|
|
|
|
178 |
external_url = "https://slabstech-dhwani-internal-api-server.hf.space/v1/chat"
|
179 |
payload = {
|
180 |
+
"prompt": chat_request.prompt,
|
181 |
+
"src_lang": chat_request.src_lang,
|
182 |
+
"tgt_lang": chat_request.src_lang
|
183 |
}
|
184 |
|
185 |
response = requests.post(
|
|
|
191 |
},
|
192 |
timeout=60
|
193 |
)
|
194 |
+
response.raise_for_status()
|
195 |
|
|
|
196 |
response_data = response.json()
|
197 |
+
response_text = response_data.get("response", "")
|
198 |
+
logger.info(f"Generated Chat response from external API: {response_text}")
|
199 |
+
return ChatResponse(response=response_text)
|
|
|
200 |
|
201 |
except requests.Timeout:
|
202 |
logger.error("External chat API request timed out")
|
|
|
213 |
async def process_audio(
|
214 |
file: UploadFile = File(...),
|
215 |
language: str = Query(..., enum=["kannada", "hindi", "tamil"]),
|
216 |
+
user_id: str = Depends(get_current_user),
|
217 |
request: Request = None,
|
218 |
):
|
219 |
logger.info("Processing audio processing request", extra={
|
220 |
"endpoint": "/v1/process_audio",
|
221 |
"filename": file.filename,
|
222 |
+
"client_ip": get_remote_address(request),
|
223 |
+
"user_id": user_id
|
224 |
})
|
225 |
|
226 |
start_time = time()
|
|
|
251 |
async def transcribe_audio(
|
252 |
file: UploadFile = File(...),
|
253 |
language: str = Query(..., enum=["kannada", "hindi", "tamil"]),
|
254 |
+
user_id: str = Depends(get_current_user),
|
255 |
request: Request = None,
|
256 |
):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
257 |
start_time = time()
|
258 |
try:
|
259 |
file_content = await file.read()
|
|
|
269 |
response.raise_for_status()
|
270 |
|
271 |
transcription = response.json().get("text", "")
|
|
|
272 |
return TranscriptionResponse(text=transcription)
|
273 |
|
274 |
except requests.Timeout:
|
275 |
raise HTTPException(status_code=504, detail="Transcription service timeout")
|
276 |
except requests.RequestException as e:
|
|
|
277 |
raise HTTPException(status_code=500, detail=f"Transcription failed: {str(e)}")
|
278 |
|
279 |
@app.post("/v1/chat_v2", response_model=TranscriptionResponse)
|
|
|
282 |
request: Request,
|
283 |
prompt: str = Form(...),
|
284 |
image: UploadFile = File(default=None),
|
285 |
+
user_id: str = Depends(get_current_user)
|
286 |
):
|
287 |
if not prompt:
|
288 |
raise HTTPException(status_code=400, detail="Prompt cannot be empty")
|
|
|
291 |
"endpoint": "/v1/chat_v2",
|
292 |
"prompt_length": len(prompt),
|
293 |
"has_image": bool(image),
|
294 |
+
"client_ip": get_remote_address(request),
|
295 |
+
"user_id": user_id
|
296 |
})
|
297 |
|
298 |
try:
|
|
|
299 |
image_data = Image.open(await image.read()) if image else None
|
300 |
response_text = f"Processed: {prompt}" + (" with image" if image_data else "")
|
301 |
return TranscriptionResponse(text=response_text)
|
302 |
except Exception as e:
|
303 |
logger.error(f"Chat_v2 processing failed: {str(e)}", exc_info=True)
|
304 |
raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")
|
305 |
+
|
306 |
class TranslationRequest(BaseModel):
|
307 |
sentences: list[str]
|
308 |
src_lang: str
|
|
|
312 |
translations: list[str]
|
313 |
|
314 |
@app.post("/v1/translate", response_model=TranslationResponse)
|
315 |
+
async def translate(
|
316 |
+
request: TranslationRequest,
|
317 |
+
user_id: str = Depends(get_current_user)
|
318 |
+
):
|
319 |
+
logger.info(f"Received translation request: {request.dict()}, user_id: {user_id}")
|
320 |
|
|
|
321 |
external_url = f"https://slabstech-dhwani-internal-api-server.hf.space/translate?src_lang={request.src_lang}&tgt_lang={request.tgt_lang}"
|
322 |
|
|
|
323 |
payload = {
|
324 |
"sentences": request.sentences,
|
325 |
"src_lang": request.src_lang,
|
|
|
327 |
}
|
328 |
|
329 |
try:
|
|
|
330 |
response = requests.post(
|
331 |
external_url,
|
332 |
json=payload,
|
|
|
334 |
"accept": "application/json",
|
335 |
"Content-Type": "application/json"
|
336 |
},
|
337 |
+
timeout=60
|
338 |
)
|
|
|
|
|
339 |
response.raise_for_status()
|
340 |
|
|
|
341 |
response_data = response.json()
|
342 |
translations = response_data.get("translations", [])
|
343 |
|
|
|
358 |
logger.error(f"Invalid JSON response: {str(e)}")
|
359 |
raise HTTPException(status_code=500, detail="Invalid response format from translation service")
|
360 |
|
|
|
|
|
361 |
if __name__ == "__main__":
|
362 |
parser = argparse.ArgumentParser(description="Run the FastAPI server.")
|
363 |
parser.add_argument("--port", type=int, default=settings.port, help="Port to run the server on.")
|
src/server/utils/auth.py
CHANGED
@@ -1,21 +1,94 @@
|
|
1 |
-
|
|
|
|
|
|
|
|
|
2 |
from fastapi import HTTPException, status, Depends
|
|
|
3 |
from pydantic_settings import BaseSettings
|
4 |
-
from config.logging_config import logger
|
|
|
5 |
|
|
|
6 |
class Settings(BaseSettings):
|
7 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
8 |
class Config:
|
9 |
env_file = ".env"
|
|
|
10 |
|
11 |
settings = Settings()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
|
13 |
-
|
14 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
|
|
|
|
|
|
|
1 |
+
import jwt
|
2 |
+
from datetime import datetime, timedelta
|
3 |
+
from pydantic import BaseModel, Field, field_validator
|
4 |
+
|
5 |
+
from fastapi.security import OAuth2PasswordBearer
|
6 |
from fastapi import HTTPException, status, Depends
|
7 |
+
from pydantic import BaseModel
|
8 |
from pydantic_settings import BaseSettings
|
9 |
+
from config.logging_config import logger # Assuming this is available
|
10 |
+
from typing import Optional
|
11 |
|
12 |
+
# Centralized Settings class (can be moved to a separate config file later)
|
13 |
class Settings(BaseSettings):
|
14 |
+
api_key_secret: str = Field(..., env="API_KEY_SECRET") # Secret key for signing JWTs
|
15 |
+
token_expiration_minutes: int = Field(30, env="TOKEN_EXPIRATION_MINUTES") # Default to 30 minutes
|
16 |
+
llm_model_name: str = "google/gemma-3-4b-it"
|
17 |
+
max_tokens: int = 512
|
18 |
+
host: str = "0.0.0.0"
|
19 |
+
port: int = 7860
|
20 |
+
chat_rate_limit: str = "100/minute"
|
21 |
+
speech_rate_limit: str = "5/minute"
|
22 |
+
|
23 |
class Config:
|
24 |
env_file = ".env"
|
25 |
+
env_file_encoding = "utf-8"
|
26 |
|
27 |
settings = Settings()
|
28 |
+
logger.info(f"Loaded API_KEY_SECRET at startup: {settings.api_key_secret}") # Add this line
|
29 |
+
|
30 |
+
# OAuth2 scheme with Bearer token
|
31 |
+
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/v1/token")
|
32 |
+
|
33 |
+
# Model for token payload
|
34 |
+
class TokenPayload(BaseModel):
|
35 |
+
sub: str # Subject (user identifier)
|
36 |
+
exp: int # Expiration timestamp
|
37 |
+
|
38 |
+
# Model for token response
|
39 |
+
class TokenResponse(BaseModel):
|
40 |
+
access_token: str
|
41 |
+
token_type: str
|
42 |
+
|
43 |
+
async def create_access_token(user_id: str) -> str:
|
44 |
+
"""
|
45 |
+
Create a JWT access token for a given user.
|
46 |
+
"""
|
47 |
+
expire = datetime.utcnow() + timedelta(minutes=settings.token_expiration_minutes)
|
48 |
+
payload = {"sub": user_id, "exp": expire.timestamp()}
|
49 |
+
logger.info(f"Signing token with API_KEY_SECRET: {settings.api_key_secret}") # Add this line
|
50 |
+
token = jwt.encode(payload, settings.api_key_secret, algorithm="HS256")
|
51 |
+
logger.info(f"Generated access token for user: {user_id}")
|
52 |
+
return token
|
53 |
|
54 |
+
async def get_current_user(token: str = Depends(oauth2_scheme)) -> str:
|
55 |
+
"""
|
56 |
+
Validate the Bearer token and return the user ID.
|
57 |
+
"""
|
58 |
+
credentials_exception = HTTPException(
|
59 |
+
status_code=status.HTTP_401_UNAUTHORIZED,
|
60 |
+
detail="Invalid authentication credentials",
|
61 |
+
headers={"WWW-Authenticate": "Bearer"},
|
62 |
+
)
|
63 |
+
|
64 |
+
try:
|
65 |
+
logger.info(f"Verifying token with API_KEY_SECRET: {settings.api_key_secret}") # Add this line
|
66 |
+
payload = jwt.decode(token, settings.api_key_secret, algorithms=["HS256"])
|
67 |
+
token_data = TokenPayload(**payload)
|
68 |
+
user_id = token_data.sub
|
69 |
+
if user_id is None:
|
70 |
+
raise credentials_exception
|
71 |
+
if datetime.utcnow().timestamp() > token_data.exp:
|
72 |
+
raise HTTPException(
|
73 |
+
status_code=status.HTTP_401_UNAUTHORIZED,
|
74 |
+
detail="Token has expired",
|
75 |
+
headers={"WWW-Authenticate": "Bearer"},
|
76 |
+
)
|
77 |
+
logger.info(f"Validated token for user: {user_id}")
|
78 |
+
return user_id
|
79 |
+
except jwt.InvalidTokenError:
|
80 |
+
logger.warning(f"Invalid token attempt: {token[:10]}...")
|
81 |
+
raise credentials_exception
|
82 |
+
except Exception as e:
|
83 |
+
logger.error(f"Token validation error: {str(e)}")
|
84 |
+
raise credentials_exception
|
85 |
|
86 |
+
# For demonstration purposes, a simple login function
|
87 |
+
# In production, replace with proper user authentication (e.g., database lookup)
|
88 |
+
async def login(user_id: str) -> TokenResponse:
|
89 |
+
"""
|
90 |
+
Generate a token for a user. In production, validate credentials here.
|
91 |
+
"""
|
92 |
+
# Placeholder: Assume user_id is valid; in reality, check against a database
|
93 |
+
token = await create_access_token(user_id=user_id)
|
94 |
+
return TokenResponse(access_token=token, token_type="bearer")
|