Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
sachin
commited on
Commit
·
26b1877
1
Parent(s):
0b7d0b0
autehntication
Browse files- src/server/main.py +27 -22
- src/server/utils/auth.py +27 -10
src/server/main.py
CHANGED
@@ -8,7 +8,7 @@ import uvicorn
|
|
8 |
from fastapi import Depends, FastAPI, File, HTTPException, Query, Request, UploadFile, Form, Security
|
9 |
from fastapi.middleware.cors import CORSMiddleware
|
10 |
from fastapi.responses import JSONResponse, RedirectResponse, StreamingResponse
|
11 |
-
from fastapi.security import
|
12 |
from pydantic import BaseModel, field_validator, Field
|
13 |
from slowapi import Limiter
|
14 |
from slowapi.util import get_remote_address
|
@@ -16,7 +16,7 @@ import requests
|
|
16 |
from PIL import Image
|
17 |
|
18 |
# Import from auth.py
|
19 |
-
from utils.auth import get_current_user, login, refresh_token, register, TokenResponse, Settings, LoginRequest, RegisterRequest
|
20 |
|
21 |
# Assuming these are in your project structure
|
22 |
from config.tts_config import SPEED, ResponseFormat, config as tts_config
|
@@ -24,18 +24,15 @@ from config.logging_config import logger
|
|
24 |
|
25 |
settings = Settings()
|
26 |
|
27 |
-
# Define OAuth2 scheme explicitly for Swagger
|
28 |
-
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/v1/token")
|
29 |
-
|
30 |
# FastAPI app setup with enhanced docs
|
31 |
app = FastAPI(
|
32 |
title="Dhwani API",
|
33 |
description="A multilingual AI-powered API supporting Indian languages for chat, text-to-speech, audio processing, and transcription. "
|
34 |
"**Authentication Guide:** \n"
|
35 |
-
"1. Register a new user via `/v1/register` with a POST request containing `username` and `password
|
36 |
"2. Obtain an access token by sending a POST request to `/v1/token` with `username` and `password`. \n"
|
37 |
-
"3. Click the 'Authorize' button (top-right), enter `
|
38 |
-
"All protected endpoints require this token for access.",
|
39 |
version="1.0.0",
|
40 |
redirect_slashes=False,
|
41 |
openapi_tags=[
|
@@ -45,7 +42,6 @@ app = FastAPI(
|
|
45 |
{"name": "Audio", "description": "Audio processing and TTS endpoints"},
|
46 |
{"name": "Translation", "description": "Text translation endpoints"}
|
47 |
],
|
48 |
-
swagger_ui_oauth2_redirect_url="/docs/oauth2-redirect",
|
49 |
)
|
50 |
|
51 |
app.add_middleware(
|
@@ -170,20 +166,24 @@ async def token(login_request: LoginRequest):
|
|
170 |
200: {"description": "New token issued", "model": TokenResponse},
|
171 |
401: {"description": "Invalid or expired token"}
|
172 |
})
|
173 |
-
async def refresh(
|
174 |
-
return
|
175 |
|
176 |
@app.post("/v1/register",
|
177 |
response_model=TokenResponse,
|
178 |
summary="Register New User",
|
179 |
-
description="Create a new user account and return an access token.
|
180 |
tags=["Authentication"],
|
181 |
responses={
|
182 |
200: {"description": "User registered successfully", "model": TokenResponse},
|
183 |
-
400: {"description": "Username already exists"}
|
|
|
184 |
})
|
185 |
-
async def register_user(
|
186 |
-
|
|
|
|
|
|
|
187 |
|
188 |
@app.post("/v1/audio/speech",
|
189 |
summary="Generate Speech from Text",
|
@@ -200,9 +200,10 @@ async def register_user(register_request: RegisterRequest):
|
|
200 |
async def generate_audio(
|
201 |
request: Request,
|
202 |
speech_request: SpeechRequest = Depends(),
|
203 |
-
|
204 |
tts_service: TTSService = Depends(get_tts_service)
|
205 |
):
|
|
|
206 |
if not speech_request.input.strip():
|
207 |
raise HTTPException(status_code=400, detail="Input cannot be empty")
|
208 |
|
@@ -276,8 +277,9 @@ class ChatResponse(BaseModel):
|
|
276 |
async def chat(
|
277 |
request: Request,
|
278 |
chat_request: ChatRequest,
|
279 |
-
|
280 |
):
|
|
|
281 |
if not chat_request.prompt:
|
282 |
raise HTTPException(status_code=400, detail="Prompt cannot be empty")
|
283 |
logger.info(f"Received prompt: {chat_request.prompt}, src_lang: {chat_request.src_lang}, user_id: {user_id}")
|
@@ -332,8 +334,9 @@ async def process_audio(
|
|
332 |
request: Request,
|
333 |
file: UploadFile = File(..., description="Audio file to process"),
|
334 |
language: str = Query(..., enum=["kannada", "hindi", "tamil"], description="Language of the audio"),
|
335 |
-
|
336 |
):
|
|
|
337 |
logger.info("Processing audio processing request", extra={
|
338 |
"endpoint": "/v1/process_audio",
|
339 |
"filename": file.filename,
|
@@ -378,9 +381,9 @@ async def process_audio(
|
|
378 |
async def transcribe_audio(
|
379 |
file: UploadFile = File(..., description="Audio file to transcribe"),
|
380 |
language: str = Query(..., enum=["kannada", "hindi", "tamil"], description="Language of the audio"),
|
381 |
-
|
382 |
-
request: Request = None,
|
383 |
):
|
|
|
384 |
start_time = time()
|
385 |
try:
|
386 |
file_content = await file.read()
|
@@ -419,8 +422,9 @@ async def chat_v2(
|
|
419 |
request: Request,
|
420 |
prompt: str = Form(..., description="Text prompt for chat"),
|
421 |
image: UploadFile = File(default=None, description="Optional image to accompany the prompt"),
|
422 |
-
|
423 |
):
|
|
|
424 |
if not prompt:
|
425 |
raise HTTPException(status_code=400, detail="Prompt cannot be empty")
|
426 |
|
@@ -473,8 +477,9 @@ class TranslationResponse(BaseModel):
|
|
473 |
})
|
474 |
async def translate(
|
475 |
request: TranslationRequest,
|
476 |
-
|
477 |
):
|
|
|
478 |
logger.info(f"Received translation request: {request.dict()}, user_id: {user_id}")
|
479 |
|
480 |
external_url = f"https://slabstech-dhwani-internal-api-server.hf.space/translate?src_lang={request.src_lang}&tgt_lang={request.tgt_lang}"
|
|
|
8 |
from fastapi import Depends, FastAPI, File, HTTPException, Query, Request, UploadFile, Form, Security
|
9 |
from fastapi.middleware.cors import CORSMiddleware
|
10 |
from fastapi.responses import JSONResponse, RedirectResponse, StreamingResponse
|
11 |
+
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
12 |
from pydantic import BaseModel, field_validator, Field
|
13 |
from slowapi import Limiter
|
14 |
from slowapi.util import get_remote_address
|
|
|
16 |
from PIL import Image
|
17 |
|
18 |
# Import from auth.py
|
19 |
+
from utils.auth import get_current_user, get_current_user_with_admin, login, refresh_token, register, TokenResponse, Settings, LoginRequest, RegisterRequest, bearer_scheme
|
20 |
|
21 |
# Assuming these are in your project structure
|
22 |
from config.tts_config import SPEED, ResponseFormat, config as tts_config
|
|
|
24 |
|
25 |
settings = Settings()
|
26 |
|
|
|
|
|
|
|
27 |
# FastAPI app setup with enhanced docs
|
28 |
app = FastAPI(
|
29 |
title="Dhwani API",
|
30 |
description="A multilingual AI-powered API supporting Indian languages for chat, text-to-speech, audio processing, and transcription. "
|
31 |
"**Authentication Guide:** \n"
|
32 |
+
"1. Register a new user via `/v1/register` with a POST request containing `username` and `password` (requires admin access). \n"
|
33 |
"2. Obtain an access token by sending a POST request to `/v1/token` with `username` and `password`. \n"
|
34 |
+
"3. Click the 'Authorize' button (top-right), enter your access token (e.g., `your_access_token`) in the 'bearerAuth' field, and click 'Authorize'. \n"
|
35 |
+
"All protected endpoints require this token for access. Only the 'admin' user (default password: adminpass) can register new users.",
|
36 |
version="1.0.0",
|
37 |
redirect_slashes=False,
|
38 |
openapi_tags=[
|
|
|
42 |
{"name": "Audio", "description": "Audio processing and TTS endpoints"},
|
43 |
{"name": "Translation", "description": "Text translation endpoints"}
|
44 |
],
|
|
|
45 |
)
|
46 |
|
47 |
app.add_middleware(
|
|
|
166 |
200: {"description": "New token issued", "model": TokenResponse},
|
167 |
401: {"description": "Invalid or expired token"}
|
168 |
})
|
169 |
+
async def refresh(credentials: HTTPAuthorizationCredentials = Depends(bearer_scheme)):
|
170 |
+
return await refresh_token(credentials)
|
171 |
|
172 |
@app.post("/v1/register",
|
173 |
response_model=TokenResponse,
|
174 |
summary="Register New User",
|
175 |
+
description="Create a new user account and return an access token. Requires admin access (use 'admin' user with password 'adminpass' initially).",
|
176 |
tags=["Authentication"],
|
177 |
responses={
|
178 |
200: {"description": "User registered successfully", "model": TokenResponse},
|
179 |
+
400: {"description": "Username already exists"},
|
180 |
+
403: {"description": "Admin access required"}
|
181 |
})
|
182 |
+
async def register_user(
|
183 |
+
register_request: RegisterRequest,
|
184 |
+
current_user: str = Depends(get_current_user_with_admin) # Enforce admin-only access
|
185 |
+
):
|
186 |
+
return await register(register_request, current_user) # Pass current_user explicitly
|
187 |
|
188 |
@app.post("/v1/audio/speech",
|
189 |
summary="Generate Speech from Text",
|
|
|
200 |
async def generate_audio(
|
201 |
request: Request,
|
202 |
speech_request: SpeechRequest = Depends(),
|
203 |
+
credentials: HTTPAuthorizationCredentials = Depends(bearer_scheme),
|
204 |
tts_service: TTSService = Depends(get_tts_service)
|
205 |
):
|
206 |
+
user_id = await get_current_user(credentials)
|
207 |
if not speech_request.input.strip():
|
208 |
raise HTTPException(status_code=400, detail="Input cannot be empty")
|
209 |
|
|
|
277 |
async def chat(
|
278 |
request: Request,
|
279 |
chat_request: ChatRequest,
|
280 |
+
credentials: HTTPAuthorizationCredentials = Depends(bearer_scheme)
|
281 |
):
|
282 |
+
user_id = await get_current_user(credentials)
|
283 |
if not chat_request.prompt:
|
284 |
raise HTTPException(status_code=400, detail="Prompt cannot be empty")
|
285 |
logger.info(f"Received prompt: {chat_request.prompt}, src_lang: {chat_request.src_lang}, user_id: {user_id}")
|
|
|
334 |
request: Request,
|
335 |
file: UploadFile = File(..., description="Audio file to process"),
|
336 |
language: str = Query(..., enum=["kannada", "hindi", "tamil"], description="Language of the audio"),
|
337 |
+
credentials: HTTPAuthorizationCredentials = Depends(bearer_scheme)
|
338 |
):
|
339 |
+
user_id = await get_current_user(credentials)
|
340 |
logger.info("Processing audio processing request", extra={
|
341 |
"endpoint": "/v1/process_audio",
|
342 |
"filename": file.filename,
|
|
|
381 |
async def transcribe_audio(
|
382 |
file: UploadFile = File(..., description="Audio file to transcribe"),
|
383 |
language: str = Query(..., enum=["kannada", "hindi", "tamil"], description="Language of the audio"),
|
384 |
+
credentials: HTTPAuthorizationCredentials = Depends(bearer_scheme)
|
|
|
385 |
):
|
386 |
+
user_id = await get_current_user(credentials)
|
387 |
start_time = time()
|
388 |
try:
|
389 |
file_content = await file.read()
|
|
|
422 |
request: Request,
|
423 |
prompt: str = Form(..., description="Text prompt for chat"),
|
424 |
image: UploadFile = File(default=None, description="Optional image to accompany the prompt"),
|
425 |
+
credentials: HTTPAuthorizationCredentials = Depends(bearer_scheme)
|
426 |
):
|
427 |
+
user_id = await get_current_user(credentials)
|
428 |
if not prompt:
|
429 |
raise HTTPException(status_code=400, detail="Prompt cannot be empty")
|
430 |
|
|
|
477 |
})
|
478 |
async def translate(
|
479 |
request: TranslationRequest,
|
480 |
+
credentials: HTTPAuthorizationCredentials = Depends(bearer_scheme)
|
481 |
):
|
482 |
+
user_id = await get_current_user(credentials)
|
483 |
logger.info(f"Received translation request: {request.dict()}, user_id: {user_id}")
|
484 |
|
485 |
external_url = f"https://slabstech-dhwani-internal-api-server.hf.space/translate?src_lang={request.src_lang}&tgt_lang={request.tgt_lang}"
|
src/server/utils/auth.py
CHANGED
@@ -1,11 +1,11 @@
|
|
1 |
import jwt
|
2 |
from datetime import datetime, timedelta
|
3 |
-
from fastapi.security import OAuth2PasswordBearer
|
4 |
from fastapi import HTTPException, status, Depends
|
|
|
5 |
from pydantic import BaseModel, Field
|
6 |
from pydantic_settings import BaseSettings
|
7 |
from config.logging_config import logger
|
8 |
-
from sqlalchemy import create_engine, Column, String
|
9 |
from sqlalchemy.ext.declarative import declarative_base
|
10 |
from sqlalchemy.orm import sessionmaker
|
11 |
from passlib.context import CryptContext
|
@@ -20,6 +20,7 @@ class User(Base):
|
|
20 |
__tablename__ = "users"
|
21 |
username = Column(String, primary_key=True, index=True)
|
22 |
password = Column(String) # Stores hashed passwords
|
|
|
23 |
|
24 |
Base.metadata.create_all(bind=engine)
|
25 |
|
@@ -31,7 +32,11 @@ def seed_initial_data():
|
|
31 |
db = SessionLocal()
|
32 |
if not db.query(User).filter_by(username="testuser").first():
|
33 |
hashed_password = pwd_context.hash("password123")
|
34 |
-
db.add(User(username="testuser", password=hashed_password))
|
|
|
|
|
|
|
|
|
35 |
db.commit()
|
36 |
db.close()
|
37 |
|
@@ -58,7 +63,8 @@ class Settings(BaseSettings):
|
|
58 |
settings = Settings()
|
59 |
logger.info(f"Loaded API_KEY_SECRET at startup: {settings.api_key_secret}")
|
60 |
|
61 |
-
|
|
|
62 |
|
63 |
class TokenPayload(BaseModel):
|
64 |
sub: str
|
@@ -84,7 +90,8 @@ async def create_access_token(user_id: str) -> str:
|
|
84 |
logger.info(f"Generated access token for user: {user_id}")
|
85 |
return token
|
86 |
|
87 |
-
async def get_current_user(
|
|
|
88 |
credentials_exception = HTTPException(
|
89 |
status_code=status.HTTP_401_UNAUTHORIZED,
|
90 |
detail="Invalid authentication credentials",
|
@@ -127,6 +134,16 @@ async def get_current_user(token: str = Depends(oauth2_scheme)) -> str:
|
|
127 |
logger.error(f"Unexpected token validation error: {str(e)}")
|
128 |
raise credentials_exception
|
129 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
130 |
async def login(login_request: LoginRequest) -> TokenResponse:
|
131 |
db = SessionLocal()
|
132 |
user = db.query(User).filter_by(username=login_request.username).first()
|
@@ -137,7 +154,7 @@ async def login(login_request: LoginRequest) -> TokenResponse:
|
|
137 |
token = await create_access_token(user_id=user.username)
|
138 |
return TokenResponse(access_token=token, token_type="bearer")
|
139 |
|
140 |
-
async def register(register_request: RegisterRequest) -> TokenResponse:
|
141 |
db = SessionLocal()
|
142 |
existing_user = db.query(User).filter_by(username=register_request.username).first()
|
143 |
if existing_user:
|
@@ -146,16 +163,16 @@ async def register(register_request: RegisterRequest) -> TokenResponse:
|
|
146 |
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Username already exists")
|
147 |
|
148 |
hashed_password = pwd_context.hash(register_request.password)
|
149 |
-
new_user = User(username=register_request.username, password=hashed_password)
|
150 |
db.add(new_user)
|
151 |
db.commit()
|
152 |
db.close()
|
153 |
|
154 |
token = await create_access_token(user_id=register_request.username)
|
155 |
-
logger.info(f"Registered and generated token for user: {register_request.username}")
|
156 |
return TokenResponse(access_token=token, token_type="bearer")
|
157 |
|
158 |
-
async def refresh_token(
|
159 |
-
user_id = await get_current_user(
|
160 |
new_token = await create_access_token(user_id=user_id)
|
161 |
return TokenResponse(access_token=new_token, token_type="bearer")
|
|
|
1 |
import jwt
|
2 |
from datetime import datetime, timedelta
|
|
|
3 |
from fastapi import HTTPException, status, Depends
|
4 |
+
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
5 |
from pydantic import BaseModel, Field
|
6 |
from pydantic_settings import BaseSettings
|
7 |
from config.logging_config import logger
|
8 |
+
from sqlalchemy import create_engine, Column, String, Boolean
|
9 |
from sqlalchemy.ext.declarative import declarative_base
|
10 |
from sqlalchemy.orm import sessionmaker
|
11 |
from passlib.context import CryptContext
|
|
|
20 |
__tablename__ = "users"
|
21 |
username = Column(String, primary_key=True, index=True)
|
22 |
password = Column(String) # Stores hashed passwords
|
23 |
+
is_admin = Column(Boolean, default=False) # New admin flag
|
24 |
|
25 |
Base.metadata.create_all(bind=engine)
|
26 |
|
|
|
32 |
db = SessionLocal()
|
33 |
if not db.query(User).filter_by(username="testuser").first():
|
34 |
hashed_password = pwd_context.hash("password123")
|
35 |
+
db.add(User(username="testuser", password=hashed_password, is_admin=False))
|
36 |
+
db.commit()
|
37 |
+
if not db.query(User).filter_by(username="admin").first():
|
38 |
+
hashed_password = pwd_context.hash("adminpass")
|
39 |
+
db.add(User(username="admin", password=hashed_password, is_admin=True))
|
40 |
db.commit()
|
41 |
db.close()
|
42 |
|
|
|
63 |
settings = Settings()
|
64 |
logger.info(f"Loaded API_KEY_SECRET at startup: {settings.api_key_secret}")
|
65 |
|
66 |
+
# Use HTTPBearer
|
67 |
+
bearer_scheme = HTTPBearer()
|
68 |
|
69 |
class TokenPayload(BaseModel):
|
70 |
sub: str
|
|
|
90 |
logger.info(f"Generated access token for user: {user_id}")
|
91 |
return token
|
92 |
|
93 |
+
async def get_current_user(credentials: HTTPAuthorizationCredentials = Depends(bearer_scheme)) -> str:
|
94 |
+
token = credentials.credentials
|
95 |
credentials_exception = HTTPException(
|
96 |
status_code=status.HTTP_401_UNAUTHORIZED,
|
97 |
detail="Invalid authentication credentials",
|
|
|
134 |
logger.error(f"Unexpected token validation error: {str(e)}")
|
135 |
raise credentials_exception
|
136 |
|
137 |
+
async def get_current_user_with_admin(credentials: HTTPAuthorizationCredentials = Depends(bearer_scheme)) -> str:
|
138 |
+
user_id = await get_current_user(credentials)
|
139 |
+
db = SessionLocal()
|
140 |
+
user = db.query(User).filter_by(username=user_id).first()
|
141 |
+
db.close()
|
142 |
+
if not user or not user.is_admin:
|
143 |
+
logger.warning(f"User {user_id} is not authorized as admin")
|
144 |
+
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Admin access required")
|
145 |
+
return user_id
|
146 |
+
|
147 |
async def login(login_request: LoginRequest) -> TokenResponse:
|
148 |
db = SessionLocal()
|
149 |
user = db.query(User).filter_by(username=login_request.username).first()
|
|
|
154 |
token = await create_access_token(user_id=user.username)
|
155 |
return TokenResponse(access_token=token, token_type="bearer")
|
156 |
|
157 |
+
async def register(register_request: RegisterRequest, current_user: str = Depends(get_current_user_with_admin)) -> TokenResponse:
|
158 |
db = SessionLocal()
|
159 |
existing_user = db.query(User).filter_by(username=register_request.username).first()
|
160 |
if existing_user:
|
|
|
163 |
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Username already exists")
|
164 |
|
165 |
hashed_password = pwd_context.hash(register_request.password)
|
166 |
+
new_user = User(username=register_request.username, password=hashed_password, is_admin=False)
|
167 |
db.add(new_user)
|
168 |
db.commit()
|
169 |
db.close()
|
170 |
|
171 |
token = await create_access_token(user_id=register_request.username)
|
172 |
+
logger.info(f"Registered and generated token for user: {register_request.username} by admin {current_user}")
|
173 |
return TokenResponse(access_token=token, token_type="bearer")
|
174 |
|
175 |
+
async def refresh_token(credentials: HTTPAuthorizationCredentials = Depends(bearer_scheme)) -> TokenResponse:
|
176 |
+
user_id = await get_current_user(credentials)
|
177 |
new_token = await create_access_token(user_id=user_id)
|
178 |
return TokenResponse(access_token=new_token, token_type="bearer")
|