sachin committed on
Commit
844386f
·
1 Parent(s): 643e32f

bearer-auth

Browse files
Files changed (4) hide show
  1. docs/menv.md +2 -2
  2. requirements.txt +2 -1
  3. src/server/main.py +39 -71
  4. src/server/utils/auth.py +84 -11
docs/menv.md CHANGED
@@ -4,6 +4,6 @@ export SPEECH_RATE_LIMIT=5/minute
4
  export CHAT_RATE_LIMIT=100/minute
5
  export EXTERNAL_TTS_URL=https://slabstech-dhwani-internal-api-server.hf.space/v1/audio/speech
6
  export EXTERNAL_ASR_URL=https://gaganyatri-asr-indic-server-cpu.hf.space
7
- export EXTERNAL_TEXT_GEN_URL=https://gaganyatri-asr-indic-server-cpu.hf.space
8
- export EXTERNAL_AUDIO_PROC_URL=https://gaganyatri-asr-indic-server-cpu.hf.space
9
  export API_KEY_SECRET=your_secret_key
 
4
  export CHAT_RATE_LIMIT=100/minute
5
  export EXTERNAL_TTS_URL=https://slabstech-dhwani-internal-api-server.hf.space/v1/audio/speech
6
  export EXTERNAL_ASR_URL=https://gaganyatri-asr-indic-server-cpu.hf.space
7
+ export EXTERNAL_TEXT_GEN_URL=https://slabstech-dhwani-internal-api-server.hf.space
8
+ export EXTERNAL_AUDIO_PROC_URL=https://slabstech-dhwani-internal-api-server.hf.space
9
  export API_KEY_SECRET=your_secret_key
requirements.txt CHANGED
@@ -4,4 +4,5 @@ pydantic_settings
4
  slowapi
5
  requests
6
  python-multipart
7
- pillow
 
 
4
  slowapi
5
  requests
6
  python-multipart
7
+ pillow
8
+ pyjwt
src/server/main.py CHANGED
@@ -9,40 +9,17 @@ from fastapi import Depends, FastAPI, File, HTTPException, Query, Request, Uploa
9
  from fastapi.middleware.cors import CORSMiddleware
10
  from fastapi.responses import JSONResponse, RedirectResponse, StreamingResponse
11
  from pydantic import BaseModel, Field, field_validator
12
- from pydantic_settings import BaseSettings
13
  from slowapi import Limiter
14
  from slowapi.util import get_remote_address
15
  import requests
16
  from PIL import Image
17
 
 
 
 
18
  # Assuming these are in your project structure
19
  from config.tts_config import SPEED, ResponseFormat, config as tts_config
20
  from config.logging_config import logger
21
- #from utils.auth import get_api_key
22
-
23
- # Configuration settings
24
- class Settings(BaseSettings):
25
- llm_model_name: str = "google/gemma-3-4b-it"
26
- max_tokens: int = 512
27
- host: str = "0.0.0.0"
28
- port: int = 7860
29
- chat_rate_limit: str = "100/minute"
30
- speech_rate_limit: str = "5/minute"
31
- external_tts_url: str = Field(..., env="EXTERNAL_TTS_URL")
32
- external_asr_url: str = Field(..., env="EXTERNAL_ASR_URL")
33
- external_text_gen_url: str = Field(..., env="EXTERNAL_TEXT_GEN_URL")
34
- external_audio_proc_url: str = Field(..., env="EXTERNAL_AUDIO_PROC_URL")
35
- api_key_secret: str = Field(..., env="API_KEY_SECRET")
36
-
37
- @field_validator("chat_rate_limit", "speech_rate_limit")
38
- def validate_rate_limit(cls, v):
39
- if not v.count("/") == 1 or not v.split("/")[0].isdigit():
40
- raise ValueError("Rate limit must be in format 'number/period' (e.g., '5/minute')")
41
- return v
42
-
43
- class Config:
44
- env_file = ".env"
45
- env_file_encoding = "utf-8"
46
 
47
  settings = Settings()
48
 
@@ -127,12 +104,17 @@ async def health_check():
127
  async def home():
128
  return RedirectResponse(url="/docs")
129
 
 
 
 
 
 
130
  @app.post("/v1/audio/speech")
131
  @limiter.limit(settings.speech_rate_limit)
132
  async def generate_audio(
133
  request: Request,
134
  speech_request: SpeechRequest = Depends(),
135
- #api_key: str = Depends(get_api_key),
136
  tts_service: TTSService = Depends(get_tts_service)
137
  ):
138
  if not speech_request.input.strip():
@@ -141,7 +123,8 @@ async def generate_audio(
141
  logger.info("Processing speech request", extra={
142
  "endpoint": "/v1/audio/speech",
143
  "input_length": len(speech_request.input),
144
- "client_ip": get_remote_address(request)
 
145
  })
146
 
147
  payload = {
@@ -167,10 +150,9 @@ async def generate_audio(
167
  headers=headers
168
  )
169
 
170
-
171
  class ChatRequest(BaseModel):
172
  prompt: str
173
- src_lang: str = "kan_Knda" # Default to Kannada
174
 
175
  @field_validator("prompt")
176
  def prompt_must_be_valid(cls, v):
@@ -181,22 +163,23 @@ class ChatRequest(BaseModel):
181
  class ChatResponse(BaseModel):
182
  response: str
183
 
184
-
185
  @app.post("/v1/chat", response_model=ChatResponse)
186
  @limiter.limit(settings.chat_rate_limit)
187
- async def chat(request: Request, chat_request: ChatRequest):
 
 
 
 
188
  if not chat_request.prompt:
189
  raise HTTPException(status_code=400, detail="Prompt cannot be empty")
190
- logger.info(f"Received prompt: {chat_request.prompt}, src_lang: {chat_request.src_lang}")
191
 
192
  try:
193
-
194
- # Call the external API instead of llm_manager.generate
195
  external_url = "https://slabstech-dhwani-internal-api-server.hf.space/v1/chat"
196
  payload = {
197
- "prompt": chat_request.prompt ,
198
- "src_lang": chat_request.src_lang,
199
- "tgt_lang" : chat_request.src_lang
200
  }
201
 
202
  response = requests.post(
@@ -208,14 +191,12 @@ async def chat(request: Request, chat_request: ChatRequest):
208
  },
209
  timeout=60
210
  )
211
- response.raise_for_status() # Raise an exception for bad status codes
212
 
213
- # Extract the response text from the API
214
  response_data = response.json()
215
- response = response_data.get("response", "")
216
- logger.info(f"Generated Chat response from external API: {response}")
217
-
218
- return ChatResponse(response=response)
219
 
220
  except requests.Timeout:
221
  logger.error("External chat API request timed out")
@@ -232,13 +213,14 @@ async def chat(request: Request, chat_request: ChatRequest):
232
  async def process_audio(
233
  file: UploadFile = File(...),
234
  language: str = Query(..., enum=["kannada", "hindi", "tamil"]),
235
- #api_key: str = Depends(get_api_key),
236
  request: Request = None,
237
  ):
238
  logger.info("Processing audio processing request", extra={
239
  "endpoint": "/v1/process_audio",
240
  "filename": file.filename,
241
- "client_ip": get_remote_address(request)
 
242
  })
243
 
244
  start_time = time()
@@ -269,16 +251,9 @@ async def process_audio(
269
  async def transcribe_audio(
270
  file: UploadFile = File(...),
271
  language: str = Query(..., enum=["kannada", "hindi", "tamil"]),
272
- #api_key: str = Depends(get_api_key),
273
  request: Request = None,
274
  ):
275
- '''
276
- logger.info("Processing transcription request", extra={
277
- "endpoint": "/v1/transcribe",
278
- "filename": file.filename,
279
- "client_ip": get_remote_address(request)
280
- })
281
- '''
282
  start_time = time()
283
  try:
284
  file_content = await file.read()
@@ -294,13 +269,11 @@ async def transcribe_audio(
294
  response.raise_for_status()
295
 
296
  transcription = response.json().get("text", "")
297
- #logger.info(f"Transcription completed in {time() - start_time:.2f} seconds")
298
  return TranscriptionResponse(text=transcription)
299
 
300
  except requests.Timeout:
301
  raise HTTPException(status_code=504, detail="Transcription service timeout")
302
  except requests.RequestException as e:
303
- #logger.error(f"Transcription request failed: {str(e)}")
304
  raise HTTPException(status_code=500, detail=f"Transcription failed: {str(e)}")
305
 
306
  @app.post("/v1/chat_v2", response_model=TranscriptionResponse)
@@ -309,7 +282,7 @@ async def chat_v2(
309
  request: Request,
310
  prompt: str = Form(...),
311
  image: UploadFile = File(default=None),
312
- #api_key: str = Depends(get_api_key)
313
  ):
314
  if not prompt:
315
  raise HTTPException(status_code=400, detail="Prompt cannot be empty")
@@ -318,18 +291,18 @@ async def chat_v2(
318
  "endpoint": "/v1/chat_v2",
319
  "prompt_length": len(prompt),
320
  "has_image": bool(image),
321
- "client_ip": get_remote_address(request)
 
322
  })
323
 
324
  try:
325
- # For demonstration, we'll just return the prompt as text
326
  image_data = Image.open(await image.read()) if image else None
327
  response_text = f"Processed: {prompt}" + (" with image" if image_data else "")
328
  return TranscriptionResponse(text=response_text)
329
  except Exception as e:
330
  logger.error(f"Chat_v2 processing failed: {str(e)}", exc_info=True)
331
  raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")
332
-
333
  class TranslationRequest(BaseModel):
334
  sentences: list[str]
335
  src_lang: str
@@ -339,13 +312,14 @@ class TranslationResponse(BaseModel):
339
  translations: list[str]
340
 
341
  @app.post("/v1/translate", response_model=TranslationResponse)
342
- async def translate(request: TranslationRequest):
343
- logger.info(f"Received translation request: {request.dict()}")
 
 
 
344
 
345
- # External API endpoint
346
  external_url = f"https://slabstech-dhwani-internal-api-server.hf.space/translate?src_lang={request.src_lang}&tgt_lang={request.tgt_lang}"
347
 
348
- # Prepare the payload matching the external API's expected format
349
  payload = {
350
  "sentences": request.sentences,
351
  "src_lang": request.src_lang,
@@ -353,7 +327,6 @@ async def translate(request: TranslationRequest):
353
  }
354
 
355
  try:
356
- # Make the POST request to the external API
357
  response = requests.post(
358
  external_url,
359
  json=payload,
@@ -361,13 +334,10 @@ async def translate(request: TranslationRequest):
361
  "accept": "application/json",
362
  "Content-Type": "application/json"
363
  },
364
- timeout=60 # Set a timeout to avoid hanging
365
  )
366
-
367
- # Raise an exception for bad status codes (4xx, 5xx)
368
  response.raise_for_status()
369
 
370
- # Extract translations from the response
371
  response_data = response.json()
372
  translations = response_data.get("translations", [])
373
 
@@ -388,8 +358,6 @@ async def translate(request: TranslationRequest):
388
  logger.error(f"Invalid JSON response: {str(e)}")
389
  raise HTTPException(status_code=500, detail="Invalid response format from translation service")
390
 
391
-
392
-
393
  if __name__ == "__main__":
394
  parser = argparse.ArgumentParser(description="Run the FastAPI server.")
395
  parser.add_argument("--port", type=int, default=settings.port, help="Port to run the server on.")
 
9
  from fastapi.middleware.cors import CORSMiddleware
10
  from fastapi.responses import JSONResponse, RedirectResponse, StreamingResponse
11
  from pydantic import BaseModel, Field, field_validator
 
12
  from slowapi import Limiter
13
  from slowapi.util import get_remote_address
14
  import requests
15
  from PIL import Image
16
 
17
+ # Import from auth.py
18
+ from utils.auth import get_current_user, login, TokenResponse, Settings
19
+
20
  # Assuming these are in your project structure
21
  from config.tts_config import SPEED, ResponseFormat, config as tts_config
22
  from config.logging_config import logger
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
  settings = Settings()
25
 
 
104
  async def home():
105
  return RedirectResponse(url="/docs")
106
 
107
+ @app.post("/v1/token", response_model=TokenResponse)
108
+ async def token(user_id: str = Form(...)):
109
+ # In production, add proper authentication (e.g., password validation)
110
+ return await login(user_id=user_id)
111
+
112
  @app.post("/v1/audio/speech")
113
  @limiter.limit(settings.speech_rate_limit)
114
  async def generate_audio(
115
  request: Request,
116
  speech_request: SpeechRequest = Depends(),
117
+ user_id: str = Depends(get_current_user),
118
  tts_service: TTSService = Depends(get_tts_service)
119
  ):
120
  if not speech_request.input.strip():
 
123
  logger.info("Processing speech request", extra={
124
  "endpoint": "/v1/audio/speech",
125
  "input_length": len(speech_request.input),
126
+ "client_ip": get_remote_address(request),
127
+ "user_id": user_id
128
  })
129
 
130
  payload = {
 
150
  headers=headers
151
  )
152
 
 
153
  class ChatRequest(BaseModel):
154
  prompt: str
155
+ src_lang: str = "kan_Knda"
156
 
157
  @field_validator("prompt")
158
  def prompt_must_be_valid(cls, v):
 
163
  class ChatResponse(BaseModel):
164
  response: str
165
 
 
166
  @app.post("/v1/chat", response_model=ChatResponse)
167
  @limiter.limit(settings.chat_rate_limit)
168
+ async def chat(
169
+ request: Request,
170
+ chat_request: ChatRequest,
171
+ user_id: str = Depends(get_current_user)
172
+ ):
173
  if not chat_request.prompt:
174
  raise HTTPException(status_code=400, detail="Prompt cannot be empty")
175
+ logger.info(f"Received prompt: {chat_request.prompt}, src_lang: {chat_request.src_lang}, user_id: {user_id}")
176
 
177
  try:
 
 
178
  external_url = "https://slabstech-dhwani-internal-api-server.hf.space/v1/chat"
179
  payload = {
180
+ "prompt": chat_request.prompt,
181
+ "src_lang": chat_request.src_lang,
182
+ "tgt_lang": chat_request.src_lang
183
  }
184
 
185
  response = requests.post(
 
191
  },
192
  timeout=60
193
  )
194
+ response.raise_for_status()
195
 
 
196
  response_data = response.json()
197
+ response_text = response_data.get("response", "")
198
+ logger.info(f"Generated Chat response from external API: {response_text}")
199
+ return ChatResponse(response=response_text)
 
200
 
201
  except requests.Timeout:
202
  logger.error("External chat API request timed out")
 
213
  async def process_audio(
214
  file: UploadFile = File(...),
215
  language: str = Query(..., enum=["kannada", "hindi", "tamil"]),
216
+ user_id: str = Depends(get_current_user),
217
  request: Request = None,
218
  ):
219
  logger.info("Processing audio processing request", extra={
220
  "endpoint": "/v1/process_audio",
221
  "filename": file.filename,
222
+ "client_ip": get_remote_address(request),
223
+ "user_id": user_id
224
  })
225
 
226
  start_time = time()
 
251
  async def transcribe_audio(
252
  file: UploadFile = File(...),
253
  language: str = Query(..., enum=["kannada", "hindi", "tamil"]),
254
+ user_id: str = Depends(get_current_user),
255
  request: Request = None,
256
  ):
 
 
 
 
 
 
 
257
  start_time = time()
258
  try:
259
  file_content = await file.read()
 
269
  response.raise_for_status()
270
 
271
  transcription = response.json().get("text", "")
 
272
  return TranscriptionResponse(text=transcription)
273
 
274
  except requests.Timeout:
275
  raise HTTPException(status_code=504, detail="Transcription service timeout")
276
  except requests.RequestException as e:
 
277
  raise HTTPException(status_code=500, detail=f"Transcription failed: {str(e)}")
278
 
279
  @app.post("/v1/chat_v2", response_model=TranscriptionResponse)
 
282
  request: Request,
283
  prompt: str = Form(...),
284
  image: UploadFile = File(default=None),
285
+ user_id: str = Depends(get_current_user)
286
  ):
287
  if not prompt:
288
  raise HTTPException(status_code=400, detail="Prompt cannot be empty")
 
291
  "endpoint": "/v1/chat_v2",
292
  "prompt_length": len(prompt),
293
  "has_image": bool(image),
294
+ "client_ip": get_remote_address(request),
295
+ "user_id": user_id
296
  })
297
 
298
  try:
 
299
  image_data = Image.open(await image.read()) if image else None
300
  response_text = f"Processed: {prompt}" + (" with image" if image_data else "")
301
  return TranscriptionResponse(text=response_text)
302
  except Exception as e:
303
  logger.error(f"Chat_v2 processing failed: {str(e)}", exc_info=True)
304
  raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")
305
+
306
  class TranslationRequest(BaseModel):
307
  sentences: list[str]
308
  src_lang: str
 
312
  translations: list[str]
313
 
314
  @app.post("/v1/translate", response_model=TranslationResponse)
315
+ async def translate(
316
+ request: TranslationRequest,
317
+ user_id: str = Depends(get_current_user)
318
+ ):
319
+ logger.info(f"Received translation request: {request.dict()}, user_id: {user_id}")
320
 
 
321
  external_url = f"https://slabstech-dhwani-internal-api-server.hf.space/translate?src_lang={request.src_lang}&tgt_lang={request.tgt_lang}"
322
 
 
323
  payload = {
324
  "sentences": request.sentences,
325
  "src_lang": request.src_lang,
 
327
  }
328
 
329
  try:
 
330
  response = requests.post(
331
  external_url,
332
  json=payload,
 
334
  "accept": "application/json",
335
  "Content-Type": "application/json"
336
  },
337
+ timeout=60
338
  )
 
 
339
  response.raise_for_status()
340
 
 
341
  response_data = response.json()
342
  translations = response_data.get("translations", [])
343
 
 
358
  logger.error(f"Invalid JSON response: {str(e)}")
359
  raise HTTPException(status_code=500, detail="Invalid response format from translation service")
360
 
 
 
361
  if __name__ == "__main__":
362
  parser = argparse.ArgumentParser(description="Run the FastAPI server.")
363
  parser.add_argument("--port", type=int, default=settings.port, help="Port to run the server on.")
src/server/utils/auth.py CHANGED
@@ -1,21 +1,94 @@
1
- from fastapi.security import APIKeyHeader
 
 
 
 
2
  from fastapi import HTTPException, status, Depends
 
3
  from pydantic_settings import BaseSettings
4
- from config.logging_config import logger
 
5
 
 
6
  class Settings(BaseSettings):
7
- api_key: str
 
 
 
 
 
 
 
 
8
  class Config:
9
  env_file = ".env"
 
10
 
11
  settings = Settings()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
- API_KEY_NAME = "X-API-Key"
14
- api_key_header = APIKeyHeader(name=API_KEY_NAME, auto_error=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
- async def get_api_key(api_key: str = Depends(api_key_header)):
17
- if api_key != settings.api_key:
18
- logger.warning(f"Failed API key attempt: {api_key}")
19
- raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API Key")
20
- logger.info("API key validated successfully")
21
- return api_key
 
 
 
 
1
+ import jwt
2
+ from datetime import datetime, timedelta
3
+ from pydantic import BaseModel, Field, field_validator
4
+
5
+ from fastapi.security import OAuth2PasswordBearer
6
  from fastapi import HTTPException, status, Depends
7
+ from pydantic import BaseModel
8
  from pydantic_settings import BaseSettings
9
+ from config.logging_config import logger # Assuming this is available
10
+ from typing import Optional
11
 
12
# Centralized application settings (can be moved to a dedicated config module later).
class Settings(BaseSettings):
    """Runtime configuration loaded from the environment and an optional ``.env`` file.

    pydantic-settings maps each field to the environment variable of the same
    name, case-insensitively (``api_key_secret`` <- ``API_KEY_SECRET``,
    ``token_expiration_minutes`` <- ``TOKEN_EXPIRATION_MINUTES``). The former
    ``Field(..., env=...)`` arguments were therefore redundant; under pydantic
    v2 they are a deprecated extra keyword that is ignored.
    """

    # Secret key used to sign and verify JWTs (HS256). Required: no default.
    api_key_secret: str
    # Lifetime of an issued access token, in minutes.
    token_expiration_minutes: int = 30
    llm_model_name: str = "google/gemma-3-4b-it"
    max_tokens: int = 512
    host: str = "0.0.0.0"
    port: int = 7860
    # Rate limits in "<count>/<period>" form, as consumed by slowapi.
    chat_rate_limit: str = "100/minute"
    speech_rate_limit: str = "5/minute"

    # pydantic v2 style configuration; replaces the deprecated nested `class Config`.
    model_config = {
        "env_file": ".env",
        "env_file_encoding": "utf-8",
    }
26
 
27
# Module-level settings instance shared by the token helpers below.
settings = Settings()
# Confirm the secret was loaded without ever writing its value to the logs:
# anyone who can read the logs could otherwise forge valid tokens.
logger.info("Loaded API_KEY_SECRET at startup (value redacted)")

# OAuth2 scheme with Bearer token; clients obtain tokens from /v1/token.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/v1/token")
32
+
33
class TokenPayload(BaseModel):
    """Claims carried inside an access token."""

    # Subject claim: the identifier of the user the token was issued to.
    sub: str
    # Expiration claim, stored as an integer timestamp.
    exp: int
37
+
38
class TokenResponse(BaseModel):
    """Body returned to a client that has been issued a token."""

    # The encoded JWT.
    access_token: str
    # Token scheme reported to the client (the login helper sets "bearer").
    token_type: str
42
+
43
async def create_access_token(user_id: str) -> str:
    """Create a signed JWT access token for the given user.

    Args:
        user_id: Identifier stored in the token's ``sub`` claim.

    Returns:
        The encoded JWT, signed with HS256 using ``settings.api_key_secret``
        and expiring ``settings.token_expiration_minutes`` minutes from now.
    """
    # Local import keeps this block self-contained; the module only imports
    # `datetime` and `timedelta` at file level.
    from datetime import timezone

    # Use an aware UTC datetime: `datetime.utcnow().timestamp()` interprets the
    # naive value as local time and shifts `exp` by the server's UTC offset.
    expire = datetime.now(timezone.utc) + timedelta(
        minutes=settings.token_expiration_minutes
    )
    # `exp` is written as whole seconds so it round-trips through
    # TokenPayload.exp (an int) without a validation error.
    payload = {"sub": user_id, "exp": int(expire.timestamp())}
    # The signing secret must never be logged.
    token = jwt.encode(payload, settings.api_key_secret, algorithm="HS256")
    logger.info(f"Generated access token for user: {user_id}")
    return token
53
 
54
async def get_current_user(token: str = Depends(oauth2_scheme)) -> str:
    """Validate the Bearer token and return the user ID it was issued to.

    Args:
        token: Encoded JWT extracted from the ``Authorization`` header by
            ``oauth2_scheme``.

    Returns:
        The ``sub`` claim of the token.

    Raises:
        HTTPException: 401 with detail "Token has expired" when the token's
            ``exp`` claim is in the past; 401 with detail "Invalid
            authentication credentials" for any other validation failure.
    """
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Invalid authentication credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )

    try:
        # jwt.decode verifies the signature and the `exp` claim itself, so no
        # manual expiry comparison is needed. The signing secret is never logged.
        payload = jwt.decode(token, settings.api_key_secret, algorithms=["HS256"])
        # Coerce `exp` to int so tokens issued with a fractional timestamp
        # still pass TokenPayload validation.
        token_data = TokenPayload(sub=payload["sub"], exp=int(payload["exp"]))
    except jwt.ExpiredSignatureError:
        # Must be handled before InvalidTokenError (its base class) so the
        # client receives the specific "expired" message.
        logger.warning(f"Expired token attempt: {token[:10]}...")
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Token has expired",
            headers={"WWW-Authenticate": "Bearer"},
        )
    except jwt.InvalidTokenError:
        logger.warning(f"Invalid token attempt: {token[:10]}...")
        raise credentials_exception
    except Exception as e:
        # Missing or malformed claims (KeyError, ValueError, validation errors).
        logger.error(f"Token validation error: {str(e)}")
        raise credentials_exception

    user_id = token_data.sub
    logger.info(f"Validated token for user: {user_id}")
    return user_id
85
 
86
# Demo-only login helper: it performs no credential check.
# A production version must authenticate the user first (e.g., database lookup).
async def login(user_id: str) -> TokenResponse:
    """Issue a bearer token for ``user_id`` without verifying credentials.

    Placeholder behavior: the caller-supplied ``user_id`` is assumed valid.
    """
    access_token = await create_access_token(user_id=user_id)
    response = TokenResponse(token_type="bearer", access_token=access_token)
    return response