Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
sachin
commited on
Commit
·
445a506
1
Parent(s):
5a8554e
improve-swagger
Browse files- src/server/main.py +181 -47
- src/server/utils/auth.py +29 -12
src/server/main.py
CHANGED
@@ -3,18 +3,21 @@ import io
|
|
3 |
from time import time
|
4 |
from typing import List, Optional
|
5 |
from abc import ABC, abstractmethod
|
|
|
|
|
6 |
|
7 |
import uvicorn
|
8 |
from fastapi import Depends, FastAPI, File, HTTPException, Query, Request, UploadFile, Form
|
9 |
from fastapi.middleware.cors import CORSMiddleware
|
10 |
from fastapi.responses import JSONResponse, RedirectResponse, StreamingResponse
|
11 |
-
from pydantic import BaseModel,
|
12 |
from slowapi import Limiter
|
13 |
from slowapi.util import get_remote_address
|
14 |
import requests
|
15 |
from PIL import Image
|
16 |
|
17 |
-
from
|
|
|
18 |
|
19 |
# Assuming these are in your project structure
|
20 |
from config.tts_config import SPEED, ResponseFormat, config as tts_config
|
@@ -22,13 +25,14 @@ from config.logging_config import logger
|
|
22 |
|
23 |
settings = Settings()
|
24 |
|
25 |
-
# FastAPI app setup
|
26 |
app = FastAPI(
|
27 |
title="Dhwani API",
|
28 |
-
description="AI
|
29 |
version="1.0.0",
|
30 |
redirect_slashes=False,
|
31 |
)
|
|
|
32 |
app.add_middleware(
|
33 |
CORSMiddleware,
|
34 |
allow_origins=["*"],
|
@@ -37,16 +41,16 @@ app.add_middleware(
|
|
37 |
allow_headers=["*"],
|
38 |
)
|
39 |
|
40 |
-
|
41 |
-
|
42 |
|
43 |
# Request/Response Models
|
44 |
class SpeechRequest(BaseModel):
|
45 |
-
input: str
|
46 |
-
voice: str
|
47 |
-
model: str
|
48 |
-
response_format: ResponseFormat = tts_config.response_format
|
49 |
-
speed: float = SPEED
|
50 |
|
51 |
@field_validator("input")
|
52 |
def input_must_be_valid(cls, v):
|
@@ -61,14 +65,34 @@ class SpeechRequest(BaseModel):
|
|
61 |
raise ValueError(f"Response format must be one of {[fmt.value for fmt in supported_formats]}")
|
62 |
return v
|
63 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
64 |
class TranscriptionResponse(BaseModel):
|
65 |
-
text: str
|
|
|
|
|
|
|
66 |
|
67 |
class TextGenerationResponse(BaseModel):
|
68 |
-
text: str
|
|
|
|
|
|
|
69 |
|
70 |
class AudioProcessingResponse(BaseModel):
|
71 |
-
result: str
|
|
|
|
|
|
|
72 |
|
73 |
# TTS Service Interface
|
74 |
class TTSService(ABC):
|
@@ -94,26 +118,68 @@ class ExternalTTSService(TTSService):
|
|
94 |
def get_tts_service() -> TTSService:
|
95 |
return ExternalTTSService()
|
96 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
97 |
|
|
|
|
|
|
|
|
|
|
|
|
|
98 |
|
99 |
-
@app.post("/v1/token",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
100 |
async def token(login_request: LoginRequest):
|
101 |
return await login(login_request)
|
102 |
|
103 |
-
@app.post("/v1/refresh",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
104 |
async def refresh(token_response: TokenResponse = Depends(refresh_token)):
|
105 |
return token_response
|
106 |
|
107 |
-
@app.
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
117 |
@limiter.limit(settings.speech_rate_limit)
|
118 |
async def generate_audio(
|
119 |
request: Request,
|
@@ -155,8 +221,8 @@ async def generate_audio(
|
|
155 |
)
|
156 |
|
157 |
class ChatRequest(BaseModel):
|
158 |
-
prompt: str
|
159 |
-
src_lang: str = "kan_Knda"
|
160 |
|
161 |
@field_validator("prompt")
|
162 |
def prompt_must_be_valid(cls, v):
|
@@ -164,10 +230,31 @@ class ChatRequest(BaseModel):
|
|
164 |
raise ValueError("Prompt cannot exceed 1000 characters")
|
165 |
return v.strip()
|
166 |
|
167 |
-
class
|
168 |
-
|
|
|
|
|
|
|
|
|
|
|
169 |
|
170 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
171 |
@limiter.limit(settings.chat_rate_limit)
|
172 |
async def chat(
|
173 |
request: Request,
|
@@ -212,13 +299,22 @@ async def chat(
|
|
212 |
logger.error(f"Error processing request: {str(e)}")
|
213 |
raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")
|
214 |
|
215 |
-
@app.post("/v1/process_audio/",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
216 |
@limiter.limit(settings.chat_rate_limit)
|
217 |
async def process_audio(
|
218 |
-
|
219 |
-
|
|
|
220 |
user_id: str = Depends(get_current_user),
|
221 |
-
request: Request = None,
|
222 |
):
|
223 |
logger.info("Processing audio processing request", extra={
|
224 |
"endpoint": "/v1/process_audio",
|
@@ -251,10 +347,18 @@ async def process_audio(
|
|
251 |
logger.error(f"Audio processing request failed: {str(e)}")
|
252 |
raise HTTPException(status_code=500, detail=f"Audio processing failed: {str(e)}")
|
253 |
|
254 |
-
@app.post("/v1/transcribe/",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
255 |
async def transcribe_audio(
|
256 |
-
file: UploadFile = File(
|
257 |
-
language: str = Query(..., enum=["kannada", "hindi", "tamil"]),
|
258 |
user_id: str = Depends(get_current_user),
|
259 |
request: Request = None,
|
260 |
):
|
@@ -280,12 +384,21 @@ async def transcribe_audio(
|
|
280 |
except requests.RequestException as e:
|
281 |
raise HTTPException(status_code=500, detail=f"Transcription failed: {str(e)}")
|
282 |
|
283 |
-
@app.post("/v1/chat_v2",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
284 |
@limiter.limit(settings.chat_rate_limit)
|
285 |
async def chat_v2(
|
286 |
request: Request,
|
287 |
-
prompt: str = Form(
|
288 |
-
image: UploadFile = File(default=None),
|
289 |
user_id: str = Depends(get_current_user)
|
290 |
):
|
291 |
if not prompt:
|
@@ -308,14 +421,35 @@ async def chat_v2(
|
|
308 |
raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")
|
309 |
|
310 |
class TranslationRequest(BaseModel):
|
311 |
-
sentences:
|
312 |
-
src_lang: str
|
313 |
-
tgt_lang: str
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
314 |
|
315 |
class TranslationResponse(BaseModel):
|
316 |
-
translations:
|
317 |
-
|
318 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
319 |
async def translate(
|
320 |
request: TranslationRequest,
|
321 |
user_id: str = Depends(get_current_user)
|
|
|
3 |
from time import time
|
4 |
from typing import List, Optional
|
5 |
from abc import ABC, abstractmethod
|
6 |
+
from pydantic import BaseModel, Field
|
7 |
+
|
8 |
|
9 |
import uvicorn
|
10 |
from fastapi import Depends, FastAPI, File, HTTPException, Query, Request, UploadFile, Form
|
11 |
from fastapi.middleware.cors import CORSMiddleware
|
12 |
from fastapi.responses import JSONResponse, RedirectResponse, StreamingResponse
|
13 |
+
from pydantic import BaseModel, field_validator
|
14 |
from slowapi import Limiter
|
15 |
from slowapi.util import get_remote_address
|
16 |
import requests
|
17 |
from PIL import Image
|
18 |
|
19 |
+
# Import from auth.py
|
20 |
+
from utils.auth import get_current_user, login, refresh_token, register, TokenResponse, Settings, LoginRequest, RegisterRequest
|
21 |
|
22 |
# Assuming these are in your project structure
|
23 |
from config.tts_config import SPEED, ResponseFormat, config as tts_config
|
|
|
25 |
|
26 |
settings = Settings()
|
27 |
|
28 |
+
# FastAPI app setup with enhanced docs
|
29 |
app = FastAPI(
|
30 |
title="Dhwani API",
|
31 |
+
description="A multilingual AI-powered API supporting Indian languages for chat, text-to-speech, audio processing, and transcription.",
|
32 |
version="1.0.0",
|
33 |
redirect_slashes=False,
|
34 |
)
|
35 |
+
|
36 |
app.add_middleware(
|
37 |
CORSMiddleware,
|
38 |
allow_origins=["*"],
|
|
|
41 |
allow_headers=["*"],
|
42 |
)
|
43 |
|
44 |
+
# Rate limiting based on user_id
|
45 |
+
limiter = Limiter(key_func=lambda request: get_current_user(request.scope.get("route").dependencies))
|
46 |
|
47 |
# Request/Response Models
|
48 |
class SpeechRequest(BaseModel):
|
49 |
+
input: str = Field(..., description="Text to convert to speech (max 1000 characters)")
|
50 |
+
voice: str = Field(..., description="Voice identifier for the TTS service")
|
51 |
+
model: str = Field(..., description="TTS model to use")
|
52 |
+
response_format: ResponseFormat = Field(tts_config.response_format, description="Audio format: mp3, flac, or wav")
|
53 |
+
speed: float = Field(SPEED, description="Speech speed (default: 1.0)")
|
54 |
|
55 |
@field_validator("input")
|
56 |
def input_must_be_valid(cls, v):
|
|
|
65 |
raise ValueError(f"Response format must be one of {[fmt.value for fmt in supported_formats]}")
|
66 |
return v
|
67 |
|
68 |
+
class Config:
|
69 |
+
schema_extra = {
|
70 |
+
"example": {
|
71 |
+
"input": "Hello, how are you?",
|
72 |
+
"voice": "female-1",
|
73 |
+
"model": "tts-model-1",
|
74 |
+
"response_format": "mp3",
|
75 |
+
"speed": 1.0
|
76 |
+
}
|
77 |
+
}
|
78 |
+
|
79 |
class TranscriptionResponse(BaseModel):
|
80 |
+
text: str = Field(..., description="Transcribed text from the audio")
|
81 |
+
|
82 |
+
class Config:
|
83 |
+
schema_extra = {"example": {"text": "Hello, how are you?"}}
|
84 |
|
85 |
class TextGenerationResponse(BaseModel):
|
86 |
+
text: str = Field(..., description="Generated text response")
|
87 |
+
|
88 |
+
class Config:
|
89 |
+
schema_extra = {"example": {"text": "Hi there, I'm doing great!"}}
|
90 |
|
91 |
class AudioProcessingResponse(BaseModel):
|
92 |
+
result: str = Field(..., description="Processed audio result")
|
93 |
+
|
94 |
+
class Config:
|
95 |
+
schema_extra = {"example": {"result": "Processed audio output"}}
|
96 |
|
97 |
# TTS Service Interface
|
98 |
class TTSService(ABC):
|
|
|
118 |
def get_tts_service() -> TTSService:
|
119 |
return ExternalTTSService()
|
120 |
|
121 |
+
# Endpoints with enhanced Swagger docs
|
122 |
+
@app.get("/v1/health",
|
123 |
+
summary="Check API Health",
|
124 |
+
description="Returns the health status of the API and the current model in use.",
|
125 |
+
tags=["Utility"],
|
126 |
+
response_model=dict)
|
127 |
+
async def health_check():
|
128 |
+
return {"status": "healthy", "model": settings.llm_model_name}
|
129 |
|
130 |
+
@app.get("/",
|
131 |
+
summary="Redirect to Docs",
|
132 |
+
description="Redirects to the Swagger UI documentation.",
|
133 |
+
tags=["Utility"])
|
134 |
+
async def home():
|
135 |
+
return RedirectResponse(url="/docs")
|
136 |
|
137 |
+
@app.post("/v1/token",
|
138 |
+
response_model=TokenResponse,
|
139 |
+
summary="User Login",
|
140 |
+
description="Authenticate a user with username and password to obtain an access token.",
|
141 |
+
tags=["Authentication"],
|
142 |
+
responses={
|
143 |
+
200: {"description": "Successful login", "model": TokenResponse},
|
144 |
+
401: {"description": "Invalid username or password"}
|
145 |
+
})
|
146 |
async def token(login_request: LoginRequest):
|
147 |
return await login(login_request)
|
148 |
|
149 |
+
@app.post("/v1/refresh",
|
150 |
+
response_model=TokenResponse,
|
151 |
+
summary="Refresh Access Token",
|
152 |
+
description="Generate a new access token using an existing valid token.",
|
153 |
+
tags=["Authentication"],
|
154 |
+
responses={
|
155 |
+
200: {"description": "New token issued", "model": TokenResponse},
|
156 |
+
401: {"description": "Invalid or expired token"}
|
157 |
+
})
|
158 |
async def refresh(token_response: TokenResponse = Depends(refresh_token)):
|
159 |
return token_response
|
160 |
|
161 |
+
@app.post("/v1/register",
|
162 |
+
response_model=TokenResponse,
|
163 |
+
summary="Register New User",
|
164 |
+
description="Create a new user account and return an access token.",
|
165 |
+
tags=["Authentication"],
|
166 |
+
responses={
|
167 |
+
200: {"description": "User registered successfully", "model": TokenResponse},
|
168 |
+
400: {"description": "Username already exists"}
|
169 |
+
})
|
170 |
+
async def register_user(register_request: RegisterRequest):
|
171 |
+
return await register(register_request)
|
172 |
+
|
173 |
+
@app.post("/v1/audio/speech",
|
174 |
+
summary="Generate Speech from Text",
|
175 |
+
description="Convert text to speech in the specified format using an external TTS service. Rate limited to 5 requests per minute per user.",
|
176 |
+
tags=["Audio"],
|
177 |
+
responses={
|
178 |
+
200: {"description": "Audio stream", "content": {"audio/mp3": {"example": "Binary audio data"}}},
|
179 |
+
400: {"description": "Invalid input"},
|
180 |
+
429: {"description": "Rate limit exceeded"},
|
181 |
+
504: {"description": "TTS service timeout"}
|
182 |
+
})
|
183 |
@limiter.limit(settings.speech_rate_limit)
|
184 |
async def generate_audio(
|
185 |
request: Request,
|
|
|
221 |
)
|
222 |
|
223 |
class ChatRequest(BaseModel):
|
224 |
+
prompt: str = Field(..., description="Text prompt for chat (max 1000 characters)")
|
225 |
+
src_lang: str = Field("kan_Knda", description="Source language code (default: Kannada)")
|
226 |
|
227 |
@field_validator("prompt")
|
228 |
def prompt_must_be_valid(cls, v):
|
|
|
230 |
raise ValueError("Prompt cannot exceed 1000 characters")
|
231 |
return v.strip()
|
232 |
|
233 |
+
class Config:
|
234 |
+
schema_extra = {
|
235 |
+
"example": {
|
236 |
+
"prompt": "Hello, how are you?",
|
237 |
+
"src_lang": "kan_Knda"
|
238 |
+
}
|
239 |
+
}
|
240 |
|
241 |
+
class ChatResponse(BaseModel):
|
242 |
+
response: str = Field(..., description="Generated chat response")
|
243 |
+
|
244 |
+
class Config:
|
245 |
+
schema_extra = {"example": {"response": "Hi there, I'm doing great!"}}
|
246 |
+
|
247 |
+
@app.post("/v1/chat",
|
248 |
+
response_model=ChatResponse,
|
249 |
+
summary="Chat with AI",
|
250 |
+
description="Generate a chat response from a prompt in the specified language. Rate limited to 100 requests per minute per user.",
|
251 |
+
tags=["Chat"],
|
252 |
+
responses={
|
253 |
+
200: {"description": "Chat response", "model": ChatResponse},
|
254 |
+
400: {"description": "Invalid prompt"},
|
255 |
+
429: {"description": "Rate limit exceeded"},
|
256 |
+
504: {"description": "Chat service timeout"}
|
257 |
+
})
|
258 |
@limiter.limit(settings.chat_rate_limit)
|
259 |
async def chat(
|
260 |
request: Request,
|
|
|
299 |
logger.error(f"Error processing request: {str(e)}")
|
300 |
raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")
|
301 |
|
302 |
+
@app.post("/v1/process_audio/",
|
303 |
+
response_model=AudioProcessingResponse,
|
304 |
+
summary="Process Audio File",
|
305 |
+
description="Process an uploaded audio file in the specified language. Rate limited to 100 requests per minute per user.",
|
306 |
+
tags=["Audio"],
|
307 |
+
responses={
|
308 |
+
200: {"description": "Processed result", "model": AudioProcessingResponse},
|
309 |
+
429: {"description": "Rate limit exceeded"},
|
310 |
+
504: {"description": "Audio processing timeout"}
|
311 |
+
})
|
312 |
@limiter.limit(settings.chat_rate_limit)
|
313 |
async def process_audio(
|
314 |
+
request: Request,
|
315 |
+
file: UploadFile = File(..., description="Audio file to process"),
|
316 |
+
language: str = Query(..., enum=["kannada", "hindi", "tamil"], description="Language of the audio"),
|
317 |
user_id: str = Depends(get_current_user),
|
|
|
318 |
):
|
319 |
logger.info("Processing audio processing request", extra={
|
320 |
"endpoint": "/v1/process_audio",
|
|
|
347 |
logger.error(f"Audio processing request failed: {str(e)}")
|
348 |
raise HTTPException(status_code=500, detail=f"Audio processing failed: {str(e)}")
|
349 |
|
350 |
+
@app.post("/v1/transcribe/",
|
351 |
+
response_model=TranscriptionResponse,
|
352 |
+
summary="Transcribe Audio File",
|
353 |
+
description="Transcribe an uploaded audio file into text in the specified language.",
|
354 |
+
tags=["Audio"],
|
355 |
+
responses={
|
356 |
+
200: {"description": "Transcription result", "model": TranscriptionResponse},
|
357 |
+
504: {"description": "Transcription service timeout"}
|
358 |
+
})
|
359 |
async def transcribe_audio(
|
360 |
+
file: UploadFile = File(..., description="Audio file to transcribe"),
|
361 |
+
language: str = Query(..., enum=["kannada", "hindi", "tamil"], description="Language of the audio"),
|
362 |
user_id: str = Depends(get_current_user),
|
363 |
request: Request = None,
|
364 |
):
|
|
|
384 |
except requests.RequestException as e:
|
385 |
raise HTTPException(status_code=500, detail=f"Transcription failed: {str(e)}")
|
386 |
|
387 |
+
@app.post("/v1/chat_v2",
|
388 |
+
response_model=TranscriptionResponse,
|
389 |
+
summary="Chat with Image (V2)",
|
390 |
+
description="Generate a response from a text prompt and optional image. Rate limited to 100 requests per minute per user.",
|
391 |
+
tags=["Chat"],
|
392 |
+
responses={
|
393 |
+
200: {"description": "Chat response", "model": TranscriptionResponse},
|
394 |
+
400: {"description": "Invalid prompt"},
|
395 |
+
429: {"description": "Rate limit exceeded"}
|
396 |
+
})
|
397 |
@limiter.limit(settings.chat_rate_limit)
|
398 |
async def chat_v2(
|
399 |
request: Request,
|
400 |
+
prompt: str = Form(..., description="Text prompt for chat"),
|
401 |
+
image: UploadFile = File(default=None, description="Optional image to accompany the prompt"),
|
402 |
user_id: str = Depends(get_current_user)
|
403 |
):
|
404 |
if not prompt:
|
|
|
421 |
raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")
|
422 |
|
423 |
class TranslationRequest(BaseModel):
|
424 |
+
sentences: List[str] = Field(..., description="List of sentences to translate")
|
425 |
+
src_lang: str = Field(..., description="Source language code")
|
426 |
+
tgt_lang: str = Field(..., description="Target language code")
|
427 |
+
|
428 |
+
class Config:
|
429 |
+
schema_extra = {
|
430 |
+
"example": {
|
431 |
+
"sentences": ["Hello", "How are you?"],
|
432 |
+
"src_lang": "en",
|
433 |
+
"tgt_lang": "kan_Knda"
|
434 |
+
}
|
435 |
+
}
|
436 |
|
437 |
class TranslationResponse(BaseModel):
|
438 |
+
translations: List[str] = Field(..., description="Translated sentences")
|
439 |
+
|
440 |
+
class Config:
|
441 |
+
schema_extra = {"example": {"translations": ["ನಮಸ್ಕಾರ", "ನೀವು ಹೇಗಿದ್ದೀರಿ?"]}}
|
442 |
+
|
443 |
+
@app.post("/v1/translate",
|
444 |
+
response_model=TranslationResponse,
|
445 |
+
summary="Translate Text",
|
446 |
+
description="Translate a list of sentences from source to target language.",
|
447 |
+
tags=["Translation"],
|
448 |
+
responses={
|
449 |
+
200: {"description": "Translation result", "model": TranslationResponse},
|
450 |
+
500: {"description": "Translation service error"},
|
451 |
+
504: {"description": "Translation service timeout"}
|
452 |
+
})
|
453 |
async def translate(
|
454 |
request: TranslationRequest,
|
455 |
user_id: str = Depends(get_current_user)
|
src/server/utils/auth.py
CHANGED
@@ -8,29 +8,25 @@ from config.logging_config import logger
|
|
8 |
from sqlalchemy import create_engine, Column, String
|
9 |
from sqlalchemy.ext.declarative import declarative_base
|
10 |
from sqlalchemy.orm import sessionmaker
|
11 |
-
|
12 |
from passlib.context import CryptContext
|
13 |
|
14 |
# SQLite database setup
|
15 |
DATABASE_URL = "sqlite:///users.db"
|
16 |
-
engine = create_engine(DATABASE_URL, connect_args={"check_same_thread": False})
|
17 |
Base = declarative_base()
|
18 |
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
19 |
|
20 |
-
# Password hashing setup
|
21 |
-
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
class User(Base):
|
26 |
__tablename__ = "users"
|
27 |
username = Column(String, primary_key=True, index=True)
|
28 |
-
password = Column(String) #
|
29 |
|
30 |
-
# Create the database tables
|
31 |
Base.metadata.create_all(bind=engine)
|
32 |
|
33 |
-
#
|
|
|
|
|
|
|
34 |
def seed_initial_data():
|
35 |
db = SessionLocal()
|
36 |
if not db.query(User).filter_by(username="testuser").first():
|
@@ -39,8 +35,7 @@ def seed_initial_data():
|
|
39 |
db.commit()
|
40 |
db.close()
|
41 |
|
42 |
-
seed_initial_data()
|
43 |
-
|
44 |
|
45 |
class Settings(BaseSettings):
|
46 |
api_key_secret: str = Field(..., env="API_KEY_SECRET")
|
@@ -77,6 +72,10 @@ class LoginRequest(BaseModel):
|
|
77 |
username: str
|
78 |
password: str
|
79 |
|
|
|
|
|
|
|
|
|
80 |
async def create_access_token(user_id: str) -> str:
|
81 |
expire = datetime.utcnow() + timedelta(minutes=settings.token_expiration_minutes)
|
82 |
payload = {"sub": user_id, "exp": expire.timestamp()}
|
@@ -138,6 +137,24 @@ async def login(login_request: LoginRequest) -> TokenResponse:
|
|
138 |
token = await create_access_token(user_id=user.username)
|
139 |
return TokenResponse(access_token=token, token_type="bearer")
|
140 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
141 |
async def refresh_token(token: str = Depends(oauth2_scheme)) -> TokenResponse:
|
142 |
user_id = await get_current_user(token)
|
143 |
new_token = await create_access_token(user_id=user_id)
|
|
|
8 |
from sqlalchemy import create_engine, Column, String
|
9 |
from sqlalchemy.ext.declarative import declarative_base
|
10 |
from sqlalchemy.orm import sessionmaker
|
|
|
11 |
from passlib.context import CryptContext
|
12 |
|
13 |
# SQLite database setup
|
14 |
DATABASE_URL = "sqlite:///users.db"
|
15 |
+
engine = create_engine(DATABASE_URL, connect_args={"check_same_thread": False})
|
16 |
Base = declarative_base()
|
17 |
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
18 |
|
|
|
|
|
|
|
|
|
|
|
19 |
class User(Base):
|
20 |
__tablename__ = "users"
|
21 |
username = Column(String, primary_key=True, index=True)
|
22 |
+
password = Column(String) # Stores hashed passwords
|
23 |
|
|
|
24 |
Base.metadata.create_all(bind=engine)
|
25 |
|
26 |
+
# Password hashing
|
27 |
+
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
28 |
+
|
29 |
+
# Seed initial data (optional)
|
30 |
def seed_initial_data():
|
31 |
db = SessionLocal()
|
32 |
if not db.query(User).filter_by(username="testuser").first():
|
|
|
35 |
db.commit()
|
36 |
db.close()
|
37 |
|
38 |
+
seed_initial_data()
|
|
|
39 |
|
40 |
class Settings(BaseSettings):
|
41 |
api_key_secret: str = Field(..., env="API_KEY_SECRET")
|
|
|
72 |
username: str
|
73 |
password: str
|
74 |
|
75 |
+
class RegisterRequest(BaseModel):
|
76 |
+
username: str
|
77 |
+
password: str
|
78 |
+
|
79 |
async def create_access_token(user_id: str) -> str:
|
80 |
expire = datetime.utcnow() + timedelta(minutes=settings.token_expiration_minutes)
|
81 |
payload = {"sub": user_id, "exp": expire.timestamp()}
|
|
|
137 |
token = await create_access_token(user_id=user.username)
|
138 |
return TokenResponse(access_token=token, token_type="bearer")
|
139 |
|
140 |
+
async def register(register_request: RegisterRequest) -> TokenResponse:
|
141 |
+
db = SessionLocal()
|
142 |
+
existing_user = db.query(User).filter_by(username=register_request.username).first()
|
143 |
+
if existing_user:
|
144 |
+
db.close()
|
145 |
+
logger.warning(f"Registration failed: Username {register_request.username} already exists")
|
146 |
+
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Username already exists")
|
147 |
+
|
148 |
+
hashed_password = pwd_context.hash(register_request.password)
|
149 |
+
new_user = User(username=register_request.username, password=hashed_password)
|
150 |
+
db.add(new_user)
|
151 |
+
db.commit()
|
152 |
+
db.close()
|
153 |
+
|
154 |
+
token = await create_access_token(user_id=register_request.username)
|
155 |
+
logger.info(f"Registered and generated token for user: {register_request.username}")
|
156 |
+
return TokenResponse(access_token=token, token_type="bearer")
|
157 |
+
|
158 |
async def refresh_token(token: str = Depends(oauth2_scheme)) -> TokenResponse:
|
159 |
user_id = await get_current_user(token)
|
160 |
new_token = await create_access_token(user_id=user_id)
|