sachin commited on
Commit
26b1877
·
1 Parent(s): 0b7d0b0

autehntication

Browse files
Files changed (2) hide show
  1. src/server/main.py +27 -22
  2. src/server/utils/auth.py +27 -10
src/server/main.py CHANGED
@@ -8,7 +8,7 @@ import uvicorn
8
  from fastapi import Depends, FastAPI, File, HTTPException, Query, Request, UploadFile, Form, Security
9
  from fastapi.middleware.cors import CORSMiddleware
10
  from fastapi.responses import JSONResponse, RedirectResponse, StreamingResponse
11
- from fastapi.security import OAuth2PasswordBearer
12
  from pydantic import BaseModel, field_validator, Field
13
  from slowapi import Limiter
14
  from slowapi.util import get_remote_address
@@ -16,7 +16,7 @@ import requests
16
  from PIL import Image
17
 
18
  # Import from auth.py
19
- from utils.auth import get_current_user, login, refresh_token, register, TokenResponse, Settings, LoginRequest, RegisterRequest
20
 
21
  # Assuming these are in your project structure
22
  from config.tts_config import SPEED, ResponseFormat, config as tts_config
@@ -24,18 +24,15 @@ from config.logging_config import logger
24
 
25
  settings = Settings()
26
 
27
- # Define OAuth2 scheme explicitly for Swagger
28
- oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/v1/token")
29
-
30
  # FastAPI app setup with enhanced docs
31
  app = FastAPI(
32
  title="Dhwani API",
33
  description="A multilingual AI-powered API supporting Indian languages for chat, text-to-speech, audio processing, and transcription. "
34
  "**Authentication Guide:** \n"
35
- "1. Register a new user via `/v1/register` with a POST request containing `username` and `password`. \n"
36
  "2. Obtain an access token by sending a POST request to `/v1/token` with `username` and `password`. \n"
37
- "3. Click the 'Authorize' button (top-right), enter `Bearer <your_access_token>` in the 'bearerAuth' field, and click 'Authorize'. \n"
38
- "All protected endpoints require this token for access.",
39
  version="1.0.0",
40
  redirect_slashes=False,
41
  openapi_tags=[
@@ -45,7 +42,6 @@ app = FastAPI(
45
  {"name": "Audio", "description": "Audio processing and TTS endpoints"},
46
  {"name": "Translation", "description": "Text translation endpoints"}
47
  ],
48
- swagger_ui_oauth2_redirect_url="/docs/oauth2-redirect",
49
  )
50
 
51
  app.add_middleware(
@@ -170,20 +166,24 @@ async def token(login_request: LoginRequest):
170
  200: {"description": "New token issued", "model": TokenResponse},
171
  401: {"description": "Invalid or expired token"}
172
  })
173
- async def refresh(token_response: TokenResponse = Depends(refresh_token)):
174
- return token_response
175
 
176
  @app.post("/v1/register",
177
  response_model=TokenResponse,
178
  summary="Register New User",
179
- description="Create a new user account and return an access token. Use this token for subsequent requests.",
180
  tags=["Authentication"],
181
  responses={
182
  200: {"description": "User registered successfully", "model": TokenResponse},
183
- 400: {"description": "Username already exists"}
 
184
  })
185
- async def register_user(register_request: RegisterRequest):
186
- return await register(register_request)
 
 
 
187
 
188
  @app.post("/v1/audio/speech",
189
  summary="Generate Speech from Text",
@@ -200,9 +200,10 @@ async def register_user(register_request: RegisterRequest):
200
  async def generate_audio(
201
  request: Request,
202
  speech_request: SpeechRequest = Depends(),
203
- user_id: str = Depends(get_current_user),
204
  tts_service: TTSService = Depends(get_tts_service)
205
  ):
 
206
  if not speech_request.input.strip():
207
  raise HTTPException(status_code=400, detail="Input cannot be empty")
208
 
@@ -276,8 +277,9 @@ class ChatResponse(BaseModel):
276
  async def chat(
277
  request: Request,
278
  chat_request: ChatRequest,
279
- user_id: str = Depends(get_current_user)
280
  ):
 
281
  if not chat_request.prompt:
282
  raise HTTPException(status_code=400, detail="Prompt cannot be empty")
283
  logger.info(f"Received prompt: {chat_request.prompt}, src_lang: {chat_request.src_lang}, user_id: {user_id}")
@@ -332,8 +334,9 @@ async def process_audio(
332
  request: Request,
333
  file: UploadFile = File(..., description="Audio file to process"),
334
  language: str = Query(..., enum=["kannada", "hindi", "tamil"], description="Language of the audio"),
335
- user_id: str = Depends(get_current_user),
336
  ):
 
337
  logger.info("Processing audio processing request", extra={
338
  "endpoint": "/v1/process_audio",
339
  "filename": file.filename,
@@ -378,9 +381,9 @@ async def process_audio(
378
  async def transcribe_audio(
379
  file: UploadFile = File(..., description="Audio file to transcribe"),
380
  language: str = Query(..., enum=["kannada", "hindi", "tamil"], description="Language of the audio"),
381
- user_id: str = Depends(get_current_user),
382
- request: Request = None,
383
  ):
 
384
  start_time = time()
385
  try:
386
  file_content = await file.read()
@@ -419,8 +422,9 @@ async def chat_v2(
419
  request: Request,
420
  prompt: str = Form(..., description="Text prompt for chat"),
421
  image: UploadFile = File(default=None, description="Optional image to accompany the prompt"),
422
- user_id: str = Depends(get_current_user)
423
  ):
 
424
  if not prompt:
425
  raise HTTPException(status_code=400, detail="Prompt cannot be empty")
426
 
@@ -473,8 +477,9 @@ class TranslationResponse(BaseModel):
473
  })
474
  async def translate(
475
  request: TranslationRequest,
476
- user_id: str = Depends(get_current_user)
477
  ):
 
478
  logger.info(f"Received translation request: {request.dict()}, user_id: {user_id}")
479
 
480
  external_url = f"https://slabstech-dhwani-internal-api-server.hf.space/translate?src_lang={request.src_lang}&tgt_lang={request.tgt_lang}"
 
8
  from fastapi import Depends, FastAPI, File, HTTPException, Query, Request, UploadFile, Form, Security
9
  from fastapi.middleware.cors import CORSMiddleware
10
  from fastapi.responses import JSONResponse, RedirectResponse, StreamingResponse
11
+ from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
12
  from pydantic import BaseModel, field_validator, Field
13
  from slowapi import Limiter
14
  from slowapi.util import get_remote_address
 
16
  from PIL import Image
17
 
18
  # Import from auth.py
19
+ from utils.auth import get_current_user, get_current_user_with_admin, login, refresh_token, register, TokenResponse, Settings, LoginRequest, RegisterRequest, bearer_scheme
20
 
21
  # Assuming these are in your project structure
22
  from config.tts_config import SPEED, ResponseFormat, config as tts_config
 
24
 
25
  settings = Settings()
26
 
 
 
 
27
  # FastAPI app setup with enhanced docs
28
  app = FastAPI(
29
  title="Dhwani API",
30
  description="A multilingual AI-powered API supporting Indian languages for chat, text-to-speech, audio processing, and transcription. "
31
  "**Authentication Guide:** \n"
32
+ "1. Register a new user via `/v1/register` with a POST request containing `username` and `password` (requires admin access). \n"
33
  "2. Obtain an access token by sending a POST request to `/v1/token` with `username` and `password`. \n"
34
+ "3. Click the 'Authorize' button (top-right), enter your access token (e.g., `your_access_token`) in the 'bearerAuth' field, and click 'Authorize'. \n"
35
+ "All protected endpoints require this token for access. Only the 'admin' user (default password: adminpass) can register new users.",
36
  version="1.0.0",
37
  redirect_slashes=False,
38
  openapi_tags=[
 
42
  {"name": "Audio", "description": "Audio processing and TTS endpoints"},
43
  {"name": "Translation", "description": "Text translation endpoints"}
44
  ],
 
45
  )
46
 
47
  app.add_middleware(
 
166
  200: {"description": "New token issued", "model": TokenResponse},
167
  401: {"description": "Invalid or expired token"}
168
  })
169
+ async def refresh(credentials: HTTPAuthorizationCredentials = Depends(bearer_scheme)):
170
+ return await refresh_token(credentials)
171
 
172
  @app.post("/v1/register",
173
  response_model=TokenResponse,
174
  summary="Register New User",
175
+ description="Create a new user account and return an access token. Requires admin access (use 'admin' user with password 'adminpass' initially).",
176
  tags=["Authentication"],
177
  responses={
178
  200: {"description": "User registered successfully", "model": TokenResponse},
179
+ 400: {"description": "Username already exists"},
180
+ 403: {"description": "Admin access required"}
181
  })
182
+ async def register_user(
183
+ register_request: RegisterRequest,
184
+ current_user: str = Depends(get_current_user_with_admin) # Enforce admin-only access
185
+ ):
186
+ return await register(register_request, current_user) # Pass current_user explicitly
187
 
188
  @app.post("/v1/audio/speech",
189
  summary="Generate Speech from Text",
 
200
  async def generate_audio(
201
  request: Request,
202
  speech_request: SpeechRequest = Depends(),
203
+ credentials: HTTPAuthorizationCredentials = Depends(bearer_scheme),
204
  tts_service: TTSService = Depends(get_tts_service)
205
  ):
206
+ user_id = await get_current_user(credentials)
207
  if not speech_request.input.strip():
208
  raise HTTPException(status_code=400, detail="Input cannot be empty")
209
 
 
277
  async def chat(
278
  request: Request,
279
  chat_request: ChatRequest,
280
+ credentials: HTTPAuthorizationCredentials = Depends(bearer_scheme)
281
  ):
282
+ user_id = await get_current_user(credentials)
283
  if not chat_request.prompt:
284
  raise HTTPException(status_code=400, detail="Prompt cannot be empty")
285
  logger.info(f"Received prompt: {chat_request.prompt}, src_lang: {chat_request.src_lang}, user_id: {user_id}")
 
334
  request: Request,
335
  file: UploadFile = File(..., description="Audio file to process"),
336
  language: str = Query(..., enum=["kannada", "hindi", "tamil"], description="Language of the audio"),
337
+ credentials: HTTPAuthorizationCredentials = Depends(bearer_scheme)
338
  ):
339
+ user_id = await get_current_user(credentials)
340
  logger.info("Processing audio processing request", extra={
341
  "endpoint": "/v1/process_audio",
342
  "filename": file.filename,
 
381
  async def transcribe_audio(
382
  file: UploadFile = File(..., description="Audio file to transcribe"),
383
  language: str = Query(..., enum=["kannada", "hindi", "tamil"], description="Language of the audio"),
384
+ credentials: HTTPAuthorizationCredentials = Depends(bearer_scheme)
 
385
  ):
386
+ user_id = await get_current_user(credentials)
387
  start_time = time()
388
  try:
389
  file_content = await file.read()
 
422
  request: Request,
423
  prompt: str = Form(..., description="Text prompt for chat"),
424
  image: UploadFile = File(default=None, description="Optional image to accompany the prompt"),
425
+ credentials: HTTPAuthorizationCredentials = Depends(bearer_scheme)
426
  ):
427
+ user_id = await get_current_user(credentials)
428
  if not prompt:
429
  raise HTTPException(status_code=400, detail="Prompt cannot be empty")
430
 
 
477
  })
478
  async def translate(
479
  request: TranslationRequest,
480
+ credentials: HTTPAuthorizationCredentials = Depends(bearer_scheme)
481
  ):
482
+ user_id = await get_current_user(credentials)
483
  logger.info(f"Received translation request: {request.dict()}, user_id: {user_id}")
484
 
485
  external_url = f"https://slabstech-dhwani-internal-api-server.hf.space/translate?src_lang={request.src_lang}&tgt_lang={request.tgt_lang}"
src/server/utils/auth.py CHANGED
@@ -1,11 +1,11 @@
1
  import jwt
2
  from datetime import datetime, timedelta
3
- from fastapi.security import OAuth2PasswordBearer
4
  from fastapi import HTTPException, status, Depends
 
5
  from pydantic import BaseModel, Field
6
  from pydantic_settings import BaseSettings
7
  from config.logging_config import logger
8
- from sqlalchemy import create_engine, Column, String
9
  from sqlalchemy.ext.declarative import declarative_base
10
  from sqlalchemy.orm import sessionmaker
11
  from passlib.context import CryptContext
@@ -20,6 +20,7 @@ class User(Base):
20
  __tablename__ = "users"
21
  username = Column(String, primary_key=True, index=True)
22
  password = Column(String) # Stores hashed passwords
 
23
 
24
  Base.metadata.create_all(bind=engine)
25
 
@@ -31,7 +32,11 @@ def seed_initial_data():
31
  db = SessionLocal()
32
  if not db.query(User).filter_by(username="testuser").first():
33
  hashed_password = pwd_context.hash("password123")
34
- db.add(User(username="testuser", password=hashed_password))
 
 
 
 
35
  db.commit()
36
  db.close()
37
 
@@ -58,7 +63,8 @@ class Settings(BaseSettings):
58
  settings = Settings()
59
  logger.info(f"Loaded API_KEY_SECRET at startup: {settings.api_key_secret}")
60
 
61
- oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/v1/token")
 
62
 
63
  class TokenPayload(BaseModel):
64
  sub: str
@@ -84,7 +90,8 @@ async def create_access_token(user_id: str) -> str:
84
  logger.info(f"Generated access token for user: {user_id}")
85
  return token
86
 
87
- async def get_current_user(token: str = Depends(oauth2_scheme)) -> str:
 
88
  credentials_exception = HTTPException(
89
  status_code=status.HTTP_401_UNAUTHORIZED,
90
  detail="Invalid authentication credentials",
@@ -127,6 +134,16 @@ async def get_current_user(token: str = Depends(oauth2_scheme)) -> str:
127
  logger.error(f"Unexpected token validation error: {str(e)}")
128
  raise credentials_exception
129
 
 
 
 
 
 
 
 
 
 
 
130
  async def login(login_request: LoginRequest) -> TokenResponse:
131
  db = SessionLocal()
132
  user = db.query(User).filter_by(username=login_request.username).first()
@@ -137,7 +154,7 @@ async def login(login_request: LoginRequest) -> TokenResponse:
137
  token = await create_access_token(user_id=user.username)
138
  return TokenResponse(access_token=token, token_type="bearer")
139
 
140
- async def register(register_request: RegisterRequest) -> TokenResponse:
141
  db = SessionLocal()
142
  existing_user = db.query(User).filter_by(username=register_request.username).first()
143
  if existing_user:
@@ -146,16 +163,16 @@ async def register(register_request: RegisterRequest) -> TokenResponse:
146
  raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Username already exists")
147
 
148
  hashed_password = pwd_context.hash(register_request.password)
149
- new_user = User(username=register_request.username, password=hashed_password)
150
  db.add(new_user)
151
  db.commit()
152
  db.close()
153
 
154
  token = await create_access_token(user_id=register_request.username)
155
- logger.info(f"Registered and generated token for user: {register_request.username}")
156
  return TokenResponse(access_token=token, token_type="bearer")
157
 
158
- async def refresh_token(token: str = Depends(oauth2_scheme)) -> TokenResponse:
159
- user_id = await get_current_user(token)
160
  new_token = await create_access_token(user_id=user_id)
161
  return TokenResponse(access_token=new_token, token_type="bearer")
 
1
  import jwt
2
  from datetime import datetime, timedelta
 
3
  from fastapi import HTTPException, status, Depends
4
+ from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
5
  from pydantic import BaseModel, Field
6
  from pydantic_settings import BaseSettings
7
  from config.logging_config import logger
8
+ from sqlalchemy import create_engine, Column, String, Boolean
9
  from sqlalchemy.ext.declarative import declarative_base
10
  from sqlalchemy.orm import sessionmaker
11
  from passlib.context import CryptContext
 
20
  __tablename__ = "users"
21
  username = Column(String, primary_key=True, index=True)
22
  password = Column(String) # Stores hashed passwords
23
+ is_admin = Column(Boolean, default=False) # New admin flag
24
 
25
  Base.metadata.create_all(bind=engine)
26
 
 
32
  db = SessionLocal()
33
  if not db.query(User).filter_by(username="testuser").first():
34
  hashed_password = pwd_context.hash("password123")
35
+ db.add(User(username="testuser", password=hashed_password, is_admin=False))
36
+ db.commit()
37
+ if not db.query(User).filter_by(username="admin").first():
38
+ hashed_password = pwd_context.hash("adminpass")
39
+ db.add(User(username="admin", password=hashed_password, is_admin=True))
40
  db.commit()
41
  db.close()
42
 
 
63
  settings = Settings()
64
  logger.info(f"Loaded API_KEY_SECRET at startup: {settings.api_key_secret}")
65
 
66
+ # Use HTTPBearer
67
+ bearer_scheme = HTTPBearer()
68
 
69
  class TokenPayload(BaseModel):
70
  sub: str
 
90
  logger.info(f"Generated access token for user: {user_id}")
91
  return token
92
 
93
+ async def get_current_user(credentials: HTTPAuthorizationCredentials = Depends(bearer_scheme)) -> str:
94
+ token = credentials.credentials
95
  credentials_exception = HTTPException(
96
  status_code=status.HTTP_401_UNAUTHORIZED,
97
  detail="Invalid authentication credentials",
 
134
  logger.error(f"Unexpected token validation error: {str(e)}")
135
  raise credentials_exception
136
 
137
+ async def get_current_user_with_admin(credentials: HTTPAuthorizationCredentials = Depends(bearer_scheme)) -> str:
138
+ user_id = await get_current_user(credentials)
139
+ db = SessionLocal()
140
+ user = db.query(User).filter_by(username=user_id).first()
141
+ db.close()
142
+ if not user or not user.is_admin:
143
+ logger.warning(f"User {user_id} is not authorized as admin")
144
+ raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Admin access required")
145
+ return user_id
146
+
147
  async def login(login_request: LoginRequest) -> TokenResponse:
148
  db = SessionLocal()
149
  user = db.query(User).filter_by(username=login_request.username).first()
 
154
  token = await create_access_token(user_id=user.username)
155
  return TokenResponse(access_token=token, token_type="bearer")
156
 
157
+ async def register(register_request: RegisterRequest, current_user: str = Depends(get_current_user_with_admin)) -> TokenResponse:
158
  db = SessionLocal()
159
  existing_user = db.query(User).filter_by(username=register_request.username).first()
160
  if existing_user:
 
163
  raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Username already exists")
164
 
165
  hashed_password = pwd_context.hash(register_request.password)
166
+ new_user = User(username=register_request.username, password=hashed_password, is_admin=False)
167
  db.add(new_user)
168
  db.commit()
169
  db.close()
170
 
171
  token = await create_access_token(user_id=register_request.username)
172
+ logger.info(f"Registered and generated token for user: {register_request.username} by admin {current_user}")
173
  return TokenResponse(access_token=token, token_type="bearer")
174
 
175
+ async def refresh_token(credentials: HTTPAuthorizationCredentials = Depends(bearer_scheme)) -> TokenResponse:
176
+ user_id = await get_current_user(credentials)
177
  new_token = await create_access_token(user_id=user_id)
178
  return TokenResponse(access_token=new_token, token_type="bearer")