File size: 6,518 Bytes
844386f
 
 
9781b82
4335561
9781b82
4335561
8d6faeb
 
 
5a8554e
 
8d6faeb
 
445a506
8d6faeb
 
 
 
 
 
445a506
8d6faeb
 
 
445a506
 
 
 
8d6faeb
 
 
5a8554e
 
8d6faeb
 
 
445a506
5a8554e
9781b82
4335561
 
844386f
 
 
 
 
 
4335561
 
 
 
844386f
9781b82
 
844386f
9781b82
 
4335561
844386f
 
 
 
4335561
 
844386f
 
 
 
 
90b38cc
 
 
 
445a506
 
 
 
844386f
 
 
4335561
844386f
 
 
9781b82
844386f
 
 
 
 
 
 
4335561
 
 
 
844386f
 
4335561
8d6faeb
 
 
 
90b38cc
844386f
4335561
 
 
 
 
844386f
 
 
 
 
4335561
844386f
 
4335561
 
 
 
 
844386f
 
4335561
844386f
9781b82
90b38cc
8d6faeb
 
 
5a8554e
8d6faeb
90b38cc
8d6faeb
 
 
445a506
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8d6faeb
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
import jwt
from datetime import datetime, timedelta
from fastapi.security import OAuth2PasswordBearer
from fastapi import HTTPException, status, Depends
from pydantic import BaseModel, Field
from pydantic_settings import BaseSettings
from config.logging_config import logger
from sqlalchemy import create_engine, Column, String
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
from passlib.context import CryptContext

# SQLite database setup
DATABASE_URL = "sqlite:///users.db"
engine = create_engine(DATABASE_URL, connect_args={"check_same_thread": False})
Base = declarative_base()
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)

class User(Base):
    __tablename__ = "users"
    username = Column(String, primary_key=True, index=True)
    password = Column(String)  # Stores hashed passwords

Base.metadata.create_all(bind=engine)

# Password hashing
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")

# Seed initial data (optional)
def seed_initial_data():
    db = SessionLocal()
    if not db.query(User).filter_by(username="testuser").first():
        hashed_password = pwd_context.hash("password123")
        db.add(User(username="testuser", password=hashed_password))
        db.commit()
    db.close()

seed_initial_data()

class Settings(BaseSettings):
    api_key_secret: str = Field(..., env="API_KEY_SECRET")
    token_expiration_minutes: int = Field(30, env="TOKEN_EXPIRATION_MINUTES")
    llm_model_name: str = "google/gemma-3-4b-it"
    max_tokens: int = 512
    host: str = "0.0.0.0"
    port: int = 7860
    chat_rate_limit: str = "100/minute"
    speech_rate_limit: str = "5/minute"
    external_tts_url: str = Field(..., env="EXTERNAL_TTS_URL")
    external_asr_url: str = Field(..., env="EXTERNAL_ASR_URL")
    external_text_gen_url: str = Field(..., env="EXTERNAL_TEXT_GEN_URL")
    external_audio_proc_url: str = Field(..., env="EXTERNAL_AUDIO_PROC_URL")

    class Config:
        env_file = ".env"
        env_file_encoding = "utf-8"

settings = Settings()
logger.info(f"Loaded API_KEY_SECRET at startup: {settings.api_key_secret}")

oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/v1/token")

class TokenPayload(BaseModel):
    sub: str
    exp: float

class TokenResponse(BaseModel):
    access_token: str
    token_type: str

class LoginRequest(BaseModel):
    username: str
    password: str

class RegisterRequest(BaseModel):
    username: str
    password: str

async def create_access_token(user_id: str) -> str:
    expire = datetime.utcnow() + timedelta(minutes=settings.token_expiration_minutes)
    payload = {"sub": user_id, "exp": expire.timestamp()}
    logger.info(f"Signing token with API_KEY_SECRET: {settings.api_key_secret}")
    token = jwt.encode(payload, settings.api_key_secret, algorithm="HS256")
    logger.info(f"Generated access token for user: {user_id}")
    return token

async def get_current_user(token: str = Depends(oauth2_scheme)) -> str:
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Invalid authentication credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )
    try:
        logger.info(f"Received token: {token}")
        logger.info(f"Verifying token with API_KEY_SECRET: {settings.api_key_secret}")
        payload = jwt.decode(token, settings.api_key_secret, algorithms=["HS256"], options={"verify_exp": False})
        logger.info(f"Decoded payload: {payload}")
        token_data = TokenPayload(**payload)
        user_id = token_data.sub
        
        db = SessionLocal()
        user = db.query(User).filter_by(username=user_id).first()
        db.close()
        if user_id is None or not user:
            logger.warning(f"Invalid or unknown user: {user_id}")
            raise credentials_exception
        
        current_time = datetime.utcnow().timestamp()
        logger.info(f"Current time: {current_time}, Token exp: {token_data.exp}")
        if current_time > token_data.exp:
            logger.warning(f"Token expired: current_time={current_time}, exp={token_data.exp}")
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Token has expired",
                headers={"WWW-Authenticate": "Bearer"},
            )
        
        logger.info(f"Validated token for user: {user_id}")
        return user_id
    except jwt.InvalidSignatureError as e:
        logger.error(f"Invalid signature error: {str(e)}")
        raise credentials_exception
    except jwt.InvalidTokenError as e:
        logger.error(f"Other token error: {str(e)}")
        raise credentials_exception
    except Exception as e:
        logger.error(f"Unexpected token validation error: {str(e)}")
        raise credentials_exception

async def login(login_request: LoginRequest) -> TokenResponse:
    db = SessionLocal()
    user = db.query(User).filter_by(username=login_request.username).first()
    db.close()
    if not user or not pwd_context.verify(login_request.password, user.password):
        logger.warning(f"Login failed for user: {login_request.username}")
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid username or password")
    token = await create_access_token(user_id=user.username)
    return TokenResponse(access_token=token, token_type="bearer")

async def register(register_request: RegisterRequest) -> TokenResponse:
    db = SessionLocal()
    existing_user = db.query(User).filter_by(username=register_request.username).first()
    if existing_user:
        db.close()
        logger.warning(f"Registration failed: Username {register_request.username} already exists")
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Username already exists")
    
    hashed_password = pwd_context.hash(register_request.password)
    new_user = User(username=register_request.username, password=hashed_password)
    db.add(new_user)
    db.commit()
    db.close()
    
    token = await create_access_token(user_id=register_request.username)
    logger.info(f"Registered and generated token for user: {register_request.username}")
    return TokenResponse(access_token=token, token_type="bearer")

async def refresh_token(token: str = Depends(oauth2_scheme)) -> TokenResponse:
    user_id = await get_current_user(token)
    new_token = await create_access_token(user_id=user_id)
    return TokenResponse(access_token=new_token, token_type="bearer")