Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
File size: 5,682 Bytes
844386f 9781b82 4335561 9781b82 4335561 8d6faeb 5a8554e 8d6faeb 5a8554e 8d6faeb 5a8554e 8d6faeb 5a8554e 8d6faeb 9781b82 5a8554e 9781b82 4335561 844386f 4335561 844386f 9781b82 844386f 9781b82 4335561 844386f 4335561 844386f 90b38cc 844386f 4335561 844386f 9781b82 844386f 4335561 844386f 4335561 8d6faeb 90b38cc 844386f 4335561 844386f 4335561 844386f 4335561 844386f 4335561 844386f 9781b82 90b38cc 8d6faeb 5a8554e 8d6faeb 90b38cc 8d6faeb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 |
from datetime import datetime, timedelta

import jwt
from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from passlib.context import CryptContext
from pydantic import BaseModel, Field
from pydantic_settings import BaseSettings, SettingsConfigDict
from sqlalchemy import Column, String, create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker

from config.logging_config import logger
# --- Persistence: file-backed SQLite database -------------------------------
DATABASE_URL = "sqlite:///users.db"
# SQLite connections refuse use from a thread other than the one that created
# them; disabling that check lets the engine's pooled connections be shared.
engine = create_engine(
    DATABASE_URL,
    connect_args={"check_same_thread": False},
)
# Declarative base that the ORM model classes in this module inherit from.
Base = declarative_base()
# Session factory bound to ``engine``; each caller opens and closes its own session.
SessionLocal = sessionmaker(bind=engine, autoflush=False, autocommit=False)

# --- Password hashing ---------------------------------------------------------
# bcrypt is the only configured hashing scheme.
pwd_context = CryptContext(deprecated="auto", schemes=["bcrypt"])
class User(Base):
    """ORM row in the ``users`` table: a username plus its password hash."""

    __tablename__ = "users"

    # Login name; primary key, with an extra index requested via index=True.
    username = Column(String, primary_key=True, index=True)
    # Hash string from pwd_context.hash(); login checks it with pwd_context.verify().
    password = Column(String)
# Ensure every table registered on Base's metadata (here: ``users``) exists.
# create_all checks first by default, so tables that already exist are left as-is.
Base.metadata.create_all(engine)
# Seed initial data (optional, for testing)
def seed_initial_data():
    """Insert a ``testuser`` account if one does not already exist.

    The password ``"password123"`` is stored only as a ``pwd_context`` hash.
    Calling this again once the user exists does nothing.
    """
    # NOTE(review): this runs at import time and provisions a well-known
    # credential; confirm that is acceptable outside development/testing.
    db = SessionLocal()
    try:
        if not db.query(User).filter_by(username="testuser").first():
            hashed_password = pwd_context.hash("password123")
            db.add(User(username="testuser", password=hashed_password))
            db.commit()
    finally:
        # Release the session even if the query, hashing, or commit raises.
        db.close()


seed_initial_data()  # Run once at startup
class Settings(BaseSettings):
    """Runtime configuration read from environment variables and ``.env``.

    pydantic-settings maps each field to the environment variable with the same
    name, case-insensitively (e.g. ``api_key_secret`` <- ``API_KEY_SECRET``), so
    no explicit per-field env name is needed. Fields without a default are
    required: ``Settings()`` raises a validation error when they are unset.
    """

    # pydantic v2 style configuration (replaces the deprecated inner ``class Config``).
    model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8")

    # Key used to sign and verify the HS256 access tokens.
    api_key_secret: str
    # Lifetime of newly issued access tokens, in minutes.
    token_expiration_minutes: int = 30
    llm_model_name: str = "google/gemma-3-4b-it"
    max_tokens: int = 512
    host: str = "0.0.0.0"
    port: int = 7860
    # Rate-limit strings ("N/period"); presumably consumed by a limiter elsewhere.
    chat_rate_limit: str = "100/minute"
    speech_rate_limit: str = "5/minute"
    # URLs of external services (all required).
    external_tts_url: str
    external_asr_url: str
    external_text_gen_url: str
    external_audio_proc_url: str
settings = Settings()
# Record that the secret was loaded, never its value: anyone able to read the
# logs could otherwise forge tokens signed with it.
logger.info("Loaded API_KEY_SECRET at startup (value redacted)")
# Bearer-token extractor used as a FastAPI dependency; advertises /v1/token as
# the token endpoint in the OpenAPI schema.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/v1/token")
class TokenPayload(BaseModel):
    """Claims read from a decoded access token."""

    sub: str  # subject: the username the token was issued for
    exp: float  # expiry timestamp, compared against the current time on validation
class TokenResponse(BaseModel):
    """Response body returned by login() and refresh_token()."""

    access_token: str  # the encoded JWT
    token_type: str  # set to "bearer" by the functions in this module
class LoginRequest(BaseModel):
    """Credentials submitted to login()."""

    username: str
    password: str  # plaintext; checked against the stored hash via pwd_context.verify
async def create_access_token(user_id: str) -> str:
    """Issue an HS256-signed JWT for ``user_id``.

    The payload holds ``sub`` (the user id) and ``exp`` (a float timestamp
    ``settings.token_expiration_minutes`` minutes from now).

    Args:
        user_id: Value stored in the ``sub`` claim.

    Returns:
        The encoded token string.
    """
    # NOTE(review): datetime.utcnow() is naive and .timestamp() treats naive
    # values as local time, so on a host whose timezone is not UTC this ``exp``
    # is shifted by the UTC offset. Token validation in this module computes
    # "now" the same way, so expiry checks stay self-consistent; switch both
    # sides to datetime.now(timezone.utc) together if external validators matter.
    expire = datetime.utcnow() + timedelta(minutes=settings.token_expiration_minutes)
    payload = {"sub": user_id, "exp": expire.timestamp()}
    # The signing secret is intentionally never logged (log access would allow forgery).
    token = jwt.encode(payload, settings.api_key_secret, algorithm="HS256")
    logger.info(f"Generated access token for user: {user_id}")
    return token
async def get_current_user(token: str = Depends(oauth2_scheme)) -> str:
    """Validate a bearer access token and return the username it names.

    The signature is checked with ``settings.api_key_secret`` (HS256).
    PyJWT's built-in expiry check is disabled. ``exp`` is instead compared
    manually against ``datetime.utcnow().timestamp()``, the same clock
    expression this module uses when it issues tokens.

    Args:
        token: Raw bearer token supplied by ``oauth2_scheme``.

    Returns:
        The ``sub`` claim, once a matching ``User`` row has been found.

    Raises:
        HTTPException: 401 with detail ``"Token has expired"`` when ``exp`` is
            in the past. 401 with ``"Invalid authentication credentials"`` for
            any other failure (bad signature, malformed token or claims,
            unknown user, unexpected error).
    """
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Invalid authentication credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )
    try:
        # The raw token and the signing secret are deliberately not logged:
        # either one in a log file lets a reader impersonate users.
        payload = jwt.decode(
            token,
            settings.api_key_secret,
            algorithms=["HS256"],
            options={"verify_exp": False},
        )
        logger.info(f"Decoded payload: {payload}")
        token_data = TokenPayload(**payload)
        user_id = token_data.sub

        db = SessionLocal()
        try:
            user = db.query(User).filter_by(username=user_id).first()
        finally:
            # Release the session even when the query raises.
            db.close()
        if user_id is None or not user:
            logger.warning(f"Invalid or unknown user: {user_id}")
            raise credentials_exception

        current_time = datetime.utcnow().timestamp()
        logger.info(f"Current time: {current_time}, Token exp: {token_data.exp}")
        if current_time > token_data.exp:
            logger.warning(f"Token expired: current_time={current_time}, exp={token_data.exp}")
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Token has expired",
                headers={"WWW-Authenticate": "Bearer"},
            )
        logger.info(f"Validated token for user: {user_id}")
        return user_id
    except HTTPException:
        # The 401s raised above already carry the right detail. Without this
        # clause the catch-all below would replace "Token has expired" with the
        # generic credentials error.
        raise
    except jwt.InvalidSignatureError as e:
        logger.error(f"Invalid signature error: {str(e)}")
        raise credentials_exception from e
    except jwt.InvalidTokenError as e:
        logger.error(f"Other token error: {str(e)}")
        raise credentials_exception from e
    except Exception as e:
        logger.error(f"Unexpected token validation error: {str(e)}")
        raise credentials_exception from e
async def login(login_request: LoginRequest) -> TokenResponse:
    """Authenticate a username/password pair and issue an access token.

    Args:
        login_request: The submitted credentials.

    Returns:
        A ``TokenResponse`` with a freshly signed JWT and ``token_type="bearer"``.

    Raises:
        HTTPException: 401 "Invalid username or password" when the user does
            not exist or the password does not match the stored hash.
    """
    db = SessionLocal()
    try:
        user = db.query(User).filter_by(username=login_request.username).first()
    finally:
        # Release the session even if the query raises.
        db.close()
    # NOTE(review): when the user is missing, `not user` short-circuits and the
    # (deliberately slow) bcrypt check is skipped, so response time may reveal
    # whether a username exists. Consider a dummy verification in that case.
    if not user or not pwd_context.verify(login_request.password, user.password):
        logger.warning(f"Login failed for user: {login_request.username}")
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid username or password")
    token = await create_access_token(user_id=user.username)
    return TokenResponse(access_token=token, token_type="bearer")
async def refresh_token(token: str = Depends(oauth2_scheme)) -> TokenResponse:
    """Exchange a still-valid bearer token for a newly issued one.

    The incoming token is validated by get_current_user, and any error it
    raises propagates unchanged. A new token for the same user is then signed
    with a fresh expiry.
    """
    username = await get_current_user(token)
    return TokenResponse(
        access_token=await create_access_token(user_id=username),
        token_type="bearer",
    )