sachin commited on
Commit
445a506
·
1 Parent(s): 5a8554e

improve-swagger

Browse files
Files changed (2) hide show
  1. src/server/main.py +181 -47
  2. src/server/utils/auth.py +29 -12
src/server/main.py CHANGED
@@ -3,18 +3,21 @@ import io
3
  from time import time
4
  from typing import List, Optional
5
  from abc import ABC, abstractmethod
 
 
6
 
7
  import uvicorn
8
  from fastapi import Depends, FastAPI, File, HTTPException, Query, Request, UploadFile, Form
9
  from fastapi.middleware.cors import CORSMiddleware
10
  from fastapi.responses import JSONResponse, RedirectResponse, StreamingResponse
11
- from pydantic import BaseModel, Field, field_validator
12
  from slowapi import Limiter
13
  from slowapi.util import get_remote_address
14
  import requests
15
  from PIL import Image
16
 
17
- from utils.auth import get_current_user, login, refresh_token, TokenResponse, Settings, LoginRequest
 
18
 
19
  # Assuming these are in your project structure
20
  from config.tts_config import SPEED, ResponseFormat, config as tts_config
@@ -22,13 +25,14 @@ from config.logging_config import logger
22
 
23
  settings = Settings()
24
 
25
- # FastAPI app setup
26
  app = FastAPI(
27
  title="Dhwani API",
28
- description="AI Chat API supporting Indian languages",
29
  version="1.0.0",
30
  redirect_slashes=False,
31
  )
 
32
  app.add_middleware(
33
  CORSMiddleware,
34
  allow_origins=["*"],
@@ -37,16 +41,16 @@ app.add_middleware(
37
  allow_headers=["*"],
38
  )
39
 
40
- limiter = Limiter(key_func=get_remote_address)
41
- app.state.limiter = limiter
42
 
43
  # Request/Response Models
44
  class SpeechRequest(BaseModel):
45
- input: str
46
- voice: str
47
- model: str
48
- response_format: ResponseFormat = tts_config.response_format
49
- speed: float = SPEED
50
 
51
  @field_validator("input")
52
  def input_must_be_valid(cls, v):
@@ -61,14 +65,34 @@ class SpeechRequest(BaseModel):
61
  raise ValueError(f"Response format must be one of {[fmt.value for fmt in supported_formats]}")
62
  return v
63
 
 
 
 
 
 
 
 
 
 
 
 
64
  class TranscriptionResponse(BaseModel):
65
- text: str
 
 
 
66
 
67
  class TextGenerationResponse(BaseModel):
68
- text: str
 
 
 
69
 
70
  class AudioProcessingResponse(BaseModel):
71
- result: str
 
 
 
72
 
73
  # TTS Service Interface
74
  class TTSService(ABC):
@@ -94,26 +118,68 @@ class ExternalTTSService(TTSService):
94
  def get_tts_service() -> TTSService:
95
  return ExternalTTSService()
96
 
 
 
 
 
 
 
 
 
97
 
 
 
 
 
 
 
98
 
99
- @app.post("/v1/token", response_model=TokenResponse)
 
 
 
 
 
 
 
 
100
  async def token(login_request: LoginRequest):
101
  return await login(login_request)
102
 
103
- @app.post("/v1/refresh", response_model=TokenResponse)
 
 
 
 
 
 
 
 
104
  async def refresh(token_response: TokenResponse = Depends(refresh_token)):
105
  return token_response
106
 
107
- @app.get("/v1/health")
108
- async def health_check():
109
- return {"status": "healthy", "model": settings.llm_model_name}
110
-
111
- @app.get("/")
112
- async def home():
113
- return RedirectResponse(url="/docs")
114
-
115
-
116
- @app.post("/v1/audio/speech")
 
 
 
 
 
 
 
 
 
 
 
 
117
  @limiter.limit(settings.speech_rate_limit)
118
  async def generate_audio(
119
  request: Request,
@@ -155,8 +221,8 @@ async def generate_audio(
155
  )
156
 
157
  class ChatRequest(BaseModel):
158
- prompt: str
159
- src_lang: str = "kan_Knda"
160
 
161
  @field_validator("prompt")
162
  def prompt_must_be_valid(cls, v):
@@ -164,10 +230,31 @@ class ChatRequest(BaseModel):
164
  raise ValueError("Prompt cannot exceed 1000 characters")
165
  return v.strip()
166
 
167
- class ChatResponse(BaseModel):
168
- response: str
 
 
 
 
 
169
 
170
- @app.post("/v1/chat", response_model=ChatResponse)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
171
  @limiter.limit(settings.chat_rate_limit)
172
  async def chat(
173
  request: Request,
@@ -212,13 +299,22 @@ async def chat(
212
  logger.error(f"Error processing request: {str(e)}")
213
  raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")
214
 
215
- @app.post("/v1/process_audio/", response_model=AudioProcessingResponse)
 
 
 
 
 
 
 
 
 
216
  @limiter.limit(settings.chat_rate_limit)
217
  async def process_audio(
218
- file: UploadFile = File(...),
219
- language: str = Query(..., enum=["kannada", "hindi", "tamil"]),
 
220
  user_id: str = Depends(get_current_user),
221
- request: Request = None,
222
  ):
223
  logger.info("Processing audio processing request", extra={
224
  "endpoint": "/v1/process_audio",
@@ -251,10 +347,18 @@ async def process_audio(
251
  logger.error(f"Audio processing request failed: {str(e)}")
252
  raise HTTPException(status_code=500, detail=f"Audio processing failed: {str(e)}")
253
 
254
- @app.post("/v1/transcribe/", response_model=TranscriptionResponse)
 
 
 
 
 
 
 
 
255
  async def transcribe_audio(
256
- file: UploadFile = File(...),
257
- language: str = Query(..., enum=["kannada", "hindi", "tamil"]),
258
  user_id: str = Depends(get_current_user),
259
  request: Request = None,
260
  ):
@@ -280,12 +384,21 @@ async def transcribe_audio(
280
  except requests.RequestException as e:
281
  raise HTTPException(status_code=500, detail=f"Transcription failed: {str(e)}")
282
 
283
- @app.post("/v1/chat_v2", response_model=TranscriptionResponse)
 
 
 
 
 
 
 
 
 
284
  @limiter.limit(settings.chat_rate_limit)
285
  async def chat_v2(
286
  request: Request,
287
- prompt: str = Form(...),
288
- image: UploadFile = File(default=None),
289
  user_id: str = Depends(get_current_user)
290
  ):
291
  if not prompt:
@@ -308,14 +421,35 @@ async def chat_v2(
308
  raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")
309
 
310
  class TranslationRequest(BaseModel):
311
- sentences: list[str]
312
- src_lang: str
313
- tgt_lang: str
 
 
 
 
 
 
 
 
 
314
 
315
  class TranslationResponse(BaseModel):
316
- translations: list[str]
317
-
318
- @app.post("/v1/translate", response_model=TranslationResponse)
 
 
 
 
 
 
 
 
 
 
 
 
319
  async def translate(
320
  request: TranslationRequest,
321
  user_id: str = Depends(get_current_user)
 
3
  from time import time
4
  from typing import List, Optional
5
  from abc import ABC, abstractmethod
6
+ from pydantic import BaseModel, Field
7
+
8
 
9
  import uvicorn
10
  from fastapi import Depends, FastAPI, File, HTTPException, Query, Request, UploadFile, Form
11
  from fastapi.middleware.cors import CORSMiddleware
12
  from fastapi.responses import JSONResponse, RedirectResponse, StreamingResponse
13
+ from pydantic import BaseModel, field_validator
14
  from slowapi import Limiter
15
  from slowapi.util import get_remote_address
16
  import requests
17
  from PIL import Image
18
 
19
+ # Import from auth.py
20
+ from utils.auth import get_current_user, login, refresh_token, register, TokenResponse, Settings, LoginRequest, RegisterRequest
21
 
22
  # Assuming these are in your project structure
23
  from config.tts_config import SPEED, ResponseFormat, config as tts_config
 
25
 
26
  settings = Settings()
27
 
28
+ # FastAPI app setup with enhanced docs
29
  app = FastAPI(
30
  title="Dhwani API",
31
+ description="A multilingual AI-powered API supporting Indian languages for chat, text-to-speech, audio processing, and transcription.",
32
  version="1.0.0",
33
  redirect_slashes=False,
34
  )
35
+
36
  app.add_middleware(
37
  CORSMiddleware,
38
  allow_origins=["*"],
 
41
  allow_headers=["*"],
42
  )
43
 
44
+ # Rate limiting based on user_id
45
+ limiter = Limiter(key_func=lambda request: get_current_user(request.scope.get("route").dependencies))
46
 
47
  # Request/Response Models
48
  class SpeechRequest(BaseModel):
49
+ input: str = Field(..., description="Text to convert to speech (max 1000 characters)")
50
+ voice: str = Field(..., description="Voice identifier for the TTS service")
51
+ model: str = Field(..., description="TTS model to use")
52
+ response_format: ResponseFormat = Field(tts_config.response_format, description="Audio format: mp3, flac, or wav")
53
+ speed: float = Field(SPEED, description="Speech speed (default: 1.0)")
54
 
55
  @field_validator("input")
56
  def input_must_be_valid(cls, v):
 
65
  raise ValueError(f"Response format must be one of {[fmt.value for fmt in supported_formats]}")
66
  return v
67
 
68
+ class Config:
69
+ schema_extra = {
70
+ "example": {
71
+ "input": "Hello, how are you?",
72
+ "voice": "female-1",
73
+ "model": "tts-model-1",
74
+ "response_format": "mp3",
75
+ "speed": 1.0
76
+ }
77
+ }
78
+
79
  class TranscriptionResponse(BaseModel):
80
+ text: str = Field(..., description="Transcribed text from the audio")
81
+
82
+ class Config:
83
+ schema_extra = {"example": {"text": "Hello, how are you?"}}
84
 
85
  class TextGenerationResponse(BaseModel):
86
+ text: str = Field(..., description="Generated text response")
87
+
88
+ class Config:
89
+ schema_extra = {"example": {"text": "Hi there, I'm doing great!"}}
90
 
91
  class AudioProcessingResponse(BaseModel):
92
+ result: str = Field(..., description="Processed audio result")
93
+
94
+ class Config:
95
+ schema_extra = {"example": {"result": "Processed audio output"}}
96
 
97
  # TTS Service Interface
98
  class TTSService(ABC):
 
118
  def get_tts_service() -> TTSService:
119
  return ExternalTTSService()
120
 
121
+ # Endpoints with enhanced Swagger docs
122
+ @app.get("/v1/health",
123
+ summary="Check API Health",
124
+ description="Returns the health status of the API and the current model in use.",
125
+ tags=["Utility"],
126
+ response_model=dict)
127
+ async def health_check():
128
+ return {"status": "healthy", "model": settings.llm_model_name}
129
 
130
+ @app.get("/",
131
+ summary="Redirect to Docs",
132
+ description="Redirects to the Swagger UI documentation.",
133
+ tags=["Utility"])
134
+ async def home():
135
+ return RedirectResponse(url="/docs")
136
 
137
+ @app.post("/v1/token",
138
+ response_model=TokenResponse,
139
+ summary="User Login",
140
+ description="Authenticate a user with username and password to obtain an access token.",
141
+ tags=["Authentication"],
142
+ responses={
143
+ 200: {"description": "Successful login", "model": TokenResponse},
144
+ 401: {"description": "Invalid username or password"}
145
+ })
146
  async def token(login_request: LoginRequest):
147
  return await login(login_request)
148
 
149
+ @app.post("/v1/refresh",
150
+ response_model=TokenResponse,
151
+ summary="Refresh Access Token",
152
+ description="Generate a new access token using an existing valid token.",
153
+ tags=["Authentication"],
154
+ responses={
155
+ 200: {"description": "New token issued", "model": TokenResponse},
156
+ 401: {"description": "Invalid or expired token"}
157
+ })
158
  async def refresh(token_response: TokenResponse = Depends(refresh_token)):
159
  return token_response
160
 
161
+ @app.post("/v1/register",
162
+ response_model=TokenResponse,
163
+ summary="Register New User",
164
+ description="Create a new user account and return an access token.",
165
+ tags=["Authentication"],
166
+ responses={
167
+ 200: {"description": "User registered successfully", "model": TokenResponse},
168
+ 400: {"description": "Username already exists"}
169
+ })
170
+ async def register_user(register_request: RegisterRequest):
171
+ return await register(register_request)
172
+
173
+ @app.post("/v1/audio/speech",
174
+ summary="Generate Speech from Text",
175
+ description="Convert text to speech in the specified format using an external TTS service. Rate limited to 5 requests per minute per user.",
176
+ tags=["Audio"],
177
+ responses={
178
+ 200: {"description": "Audio stream", "content": {"audio/mp3": {"example": "Binary audio data"}}},
179
+ 400: {"description": "Invalid input"},
180
+ 429: {"description": "Rate limit exceeded"},
181
+ 504: {"description": "TTS service timeout"}
182
+ })
183
  @limiter.limit(settings.speech_rate_limit)
184
  async def generate_audio(
185
  request: Request,
 
221
  )
222
 
223
  class ChatRequest(BaseModel):
224
+ prompt: str = Field(..., description="Text prompt for chat (max 1000 characters)")
225
+ src_lang: str = Field("kan_Knda", description="Source language code (default: Kannada)")
226
 
227
  @field_validator("prompt")
228
  def prompt_must_be_valid(cls, v):
 
230
  raise ValueError("Prompt cannot exceed 1000 characters")
231
  return v.strip()
232
 
233
+ class Config:
234
+ schema_extra = {
235
+ "example": {
236
+ "prompt": "Hello, how are you?",
237
+ "src_lang": "kan_Knda"
238
+ }
239
+ }
240
 
241
+ class ChatResponse(BaseModel):
242
+ response: str = Field(..., description="Generated chat response")
243
+
244
+ class Config:
245
+ schema_extra = {"example": {"response": "Hi there, I'm doing great!"}}
246
+
247
+ @app.post("/v1/chat",
248
+ response_model=ChatResponse,
249
+ summary="Chat with AI",
250
+ description="Generate a chat response from a prompt in the specified language. Rate limited to 100 requests per minute per user.",
251
+ tags=["Chat"],
252
+ responses={
253
+ 200: {"description": "Chat response", "model": ChatResponse},
254
+ 400: {"description": "Invalid prompt"},
255
+ 429: {"description": "Rate limit exceeded"},
256
+ 504: {"description": "Chat service timeout"}
257
+ })
258
  @limiter.limit(settings.chat_rate_limit)
259
  async def chat(
260
  request: Request,
 
299
  logger.error(f"Error processing request: {str(e)}")
300
  raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")
301
 
302
+ @app.post("/v1/process_audio/",
303
+ response_model=AudioProcessingResponse,
304
+ summary="Process Audio File",
305
+ description="Process an uploaded audio file in the specified language. Rate limited to 100 requests per minute per user.",
306
+ tags=["Audio"],
307
+ responses={
308
+ 200: {"description": "Processed result", "model": AudioProcessingResponse},
309
+ 429: {"description": "Rate limit exceeded"},
310
+ 504: {"description": "Audio processing timeout"}
311
+ })
312
  @limiter.limit(settings.chat_rate_limit)
313
  async def process_audio(
314
+ request: Request,
315
+ file: UploadFile = File(..., description="Audio file to process"),
316
+ language: str = Query(..., enum=["kannada", "hindi", "tamil"], description="Language of the audio"),
317
  user_id: str = Depends(get_current_user),
 
318
  ):
319
  logger.info("Processing audio processing request", extra={
320
  "endpoint": "/v1/process_audio",
 
347
  logger.error(f"Audio processing request failed: {str(e)}")
348
  raise HTTPException(status_code=500, detail=f"Audio processing failed: {str(e)}")
349
 
350
+ @app.post("/v1/transcribe/",
351
+ response_model=TranscriptionResponse,
352
+ summary="Transcribe Audio File",
353
+ description="Transcribe an uploaded audio file into text in the specified language.",
354
+ tags=["Audio"],
355
+ responses={
356
+ 200: {"description": "Transcription result", "model": TranscriptionResponse},
357
+ 504: {"description": "Transcription service timeout"}
358
+ })
359
  async def transcribe_audio(
360
+ file: UploadFile = File(..., description="Audio file to transcribe"),
361
+ language: str = Query(..., enum=["kannada", "hindi", "tamil"], description="Language of the audio"),
362
  user_id: str = Depends(get_current_user),
363
  request: Request = None,
364
  ):
 
384
  except requests.RequestException as e:
385
  raise HTTPException(status_code=500, detail=f"Transcription failed: {str(e)}")
386
 
387
+ @app.post("/v1/chat_v2",
388
+ response_model=TranscriptionResponse,
389
+ summary="Chat with Image (V2)",
390
+ description="Generate a response from a text prompt and optional image. Rate limited to 100 requests per minute per user.",
391
+ tags=["Chat"],
392
+ responses={
393
+ 200: {"description": "Chat response", "model": TranscriptionResponse},
394
+ 400: {"description": "Invalid prompt"},
395
+ 429: {"description": "Rate limit exceeded"}
396
+ })
397
  @limiter.limit(settings.chat_rate_limit)
398
  async def chat_v2(
399
  request: Request,
400
+ prompt: str = Form(..., description="Text prompt for chat"),
401
+ image: UploadFile = File(default=None, description="Optional image to accompany the prompt"),
402
  user_id: str = Depends(get_current_user)
403
  ):
404
  if not prompt:
 
421
  raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")
422
 
423
  class TranslationRequest(BaseModel):
424
+ sentences: List[str] = Field(..., description="List of sentences to translate")
425
+ src_lang: str = Field(..., description="Source language code")
426
+ tgt_lang: str = Field(..., description="Target language code")
427
+
428
+ class Config:
429
+ schema_extra = {
430
+ "example": {
431
+ "sentences": ["Hello", "How are you?"],
432
+ "src_lang": "en",
433
+ "tgt_lang": "kan_Knda"
434
+ }
435
+ }
436
 
437
  class TranslationResponse(BaseModel):
438
+ translations: List[str] = Field(..., description="Translated sentences")
439
+
440
+ class Config:
441
+ schema_extra = {"example": {"translations": ["ನಮಸ್ಕಾರ", "ನೀವು ಹೇಗಿದ್ದೀರಿ?"]}}
442
+
443
+ @app.post("/v1/translate",
444
+ response_model=TranslationResponse,
445
+ summary="Translate Text",
446
+ description="Translate a list of sentences from source to target language.",
447
+ tags=["Translation"],
448
+ responses={
449
+ 200: {"description": "Translation result", "model": TranslationResponse},
450
+ 500: {"description": "Translation service error"},
451
+ 504: {"description": "Translation service timeout"}
452
+ })
453
  async def translate(
454
  request: TranslationRequest,
455
  user_id: str = Depends(get_current_user)
src/server/utils/auth.py CHANGED
@@ -8,29 +8,25 @@ from config.logging_config import logger
8
  from sqlalchemy import create_engine, Column, String
9
  from sqlalchemy.ext.declarative import declarative_base
10
  from sqlalchemy.orm import sessionmaker
11
-
12
  from passlib.context import CryptContext
13
 
14
  # SQLite database setup
15
  DATABASE_URL = "sqlite:///users.db"
16
- engine = create_engine(DATABASE_URL, connect_args={"check_same_thread": False}) # For SQLite threading
17
  Base = declarative_base()
18
  SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
19
 
20
- # Password hashing setup
21
- pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
22
-
23
-
24
-
25
  class User(Base):
26
  __tablename__ = "users"
27
  username = Column(String, primary_key=True, index=True)
28
- password = Column(String) # Now stores hashed passwords
29
 
30
- # Create the database tables
31
  Base.metadata.create_all(bind=engine)
32
 
33
- # Seed initial data (optional, for testing)
 
 
 
34
  def seed_initial_data():
35
  db = SessionLocal()
36
  if not db.query(User).filter_by(username="testuser").first():
@@ -39,8 +35,7 @@ def seed_initial_data():
39
  db.commit()
40
  db.close()
41
 
42
- seed_initial_data() # Run once at startup
43
-
44
 
45
  class Settings(BaseSettings):
46
  api_key_secret: str = Field(..., env="API_KEY_SECRET")
@@ -77,6 +72,10 @@ class LoginRequest(BaseModel):
77
  username: str
78
  password: str
79
 
 
 
 
 
80
  async def create_access_token(user_id: str) -> str:
81
  expire = datetime.utcnow() + timedelta(minutes=settings.token_expiration_minutes)
82
  payload = {"sub": user_id, "exp": expire.timestamp()}
@@ -138,6 +137,24 @@ async def login(login_request: LoginRequest) -> TokenResponse:
138
  token = await create_access_token(user_id=user.username)
139
  return TokenResponse(access_token=token, token_type="bearer")
140
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
141
  async def refresh_token(token: str = Depends(oauth2_scheme)) -> TokenResponse:
142
  user_id = await get_current_user(token)
143
  new_token = await create_access_token(user_id=user_id)
 
8
  from sqlalchemy import create_engine, Column, String
9
  from sqlalchemy.ext.declarative import declarative_base
10
  from sqlalchemy.orm import sessionmaker
 
11
  from passlib.context import CryptContext
12
 
13
  # SQLite database setup
14
  DATABASE_URL = "sqlite:///users.db"
15
+ engine = create_engine(DATABASE_URL, connect_args={"check_same_thread": False})
16
  Base = declarative_base()
17
  SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
18
 
 
 
 
 
 
19
  class User(Base):
20
  __tablename__ = "users"
21
  username = Column(String, primary_key=True, index=True)
22
+ password = Column(String) # Stores hashed passwords
23
 
 
24
  Base.metadata.create_all(bind=engine)
25
 
26
+ # Password hashing
27
+ pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
28
+
29
+ # Seed initial data (optional)
30
  def seed_initial_data():
31
  db = SessionLocal()
32
  if not db.query(User).filter_by(username="testuser").first():
 
35
  db.commit()
36
  db.close()
37
 
38
+ seed_initial_data()
 
39
 
40
  class Settings(BaseSettings):
41
  api_key_secret: str = Field(..., env="API_KEY_SECRET")
 
72
  username: str
73
  password: str
74
 
75
+ class RegisterRequest(BaseModel):
76
+ username: str
77
+ password: str
78
+
79
  async def create_access_token(user_id: str) -> str:
80
  expire = datetime.utcnow() + timedelta(minutes=settings.token_expiration_minutes)
81
  payload = {"sub": user_id, "exp": expire.timestamp()}
 
137
  token = await create_access_token(user_id=user.username)
138
  return TokenResponse(access_token=token, token_type="bearer")
139
 
140
+ async def register(register_request: RegisterRequest) -> TokenResponse:
141
+ db = SessionLocal()
142
+ existing_user = db.query(User).filter_by(username=register_request.username).first()
143
+ if existing_user:
144
+ db.close()
145
+ logger.warning(f"Registration failed: Username {register_request.username} already exists")
146
+ raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Username already exists")
147
+
148
+ hashed_password = pwd_context.hash(register_request.password)
149
+ new_user = User(username=register_request.username, password=hashed_password)
150
+ db.add(new_user)
151
+ db.commit()
152
+ db.close()
153
+
154
+ token = await create_access_token(user_id=register_request.username)
155
+ logger.info(f"Registered and generated token for user: {register_request.username}")
156
+ return TokenResponse(access_token=token, token_type="bearer")
157
+
158
  async def refresh_token(token: str = Depends(oauth2_scheme)) -> TokenResponse:
159
  user_id = await get_current_user(token)
160
  new_token = await create_access_token(user_id=user_id)