Spaces:
Build error
Build error
import sqlite3 | |
import uuid | |
import os | |
import logging | |
from datetime import datetime, timedelta | |
import hashlib # Use hashlib instead of jwt | |
from passlib.hash import bcrypt | |
from dotenv import load_dotenv | |
from fastapi import Depends, HTTPException | |
from fastapi.security import OAuth2PasswordBearer | |
from pydantic import BaseModel | |
from typing import Optional | |
from fastapi import HTTPException, status | |
import jwt | |
from jwt.exceptions import PyJWTError | |
import sqlite3 | |
# Load environment variables | |
load_dotenv() | |
# Configure logging | |
logging.basicConfig( | |
level=logging.INFO, | |
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' | |
) | |
logger = logging.getLogger('auth') | |
# Security configuration | |
SECRET_KEY = os.getenv("JWT_SECRET", "your-secret-key-for-development-only") | |
ALGORITHM = "HS256" | |
JWT_EXPIRATION_DELTA = timedelta(days=1) # Token valid for 1 day | |
# Database path from environment variable or default | |
# Fix the incorrect DB_PATH | |
DB_PATH = os.getenv("DB_PATH", os.path.join(os.path.dirname(__file__), "data/user_data.db")) | |
# FastAPI OAuth2 scheme | |
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") | |
# Pydantic models for FastAPI | |
class User(BaseModel): | |
id: str | |
email: str | |
subscription_tier: str = "free_tier" | |
subscription_expiry: Optional[datetime] = None | |
api_calls_remaining: int = 5 | |
last_reset_date: Optional[datetime] = None | |
class UserCreate(BaseModel): | |
email: str | |
password: str | |
class Token(BaseModel): | |
access_token: str | |
token_type: str | |
class TokenData(BaseModel): | |
user_id: Optional[str] = None | |
# Subscription tiers and limits | |
# Update the SUBSCRIPTION_TIERS dictionary | |
SUBSCRIPTION_TIERS = { | |
"free_tier": { | |
"price": 0, | |
"currency": "INR", | |
"features": ["basic_document_analysis", "basic_risk_assessment"], | |
"limits": { | |
"document_size_mb": 5, | |
"documents_per_month": 3, | |
"video_size_mb": 0, | |
"audio_size_mb": 0 | |
} | |
}, | |
"standard_tier": { | |
"price": 799, | |
"currency": "INR", | |
"features": ["basic_document_analysis", "basic_risk_assessment", "video_analysis", "audio_analysis", "chatbot"], | |
"limits": { | |
"document_size_mb": 20, | |
"documents_per_month": 20, | |
"video_size_mb": 100, | |
"audio_size_mb": 50 | |
} | |
}, | |
"premium_tier": { | |
"price": 1499, | |
"currency": "INR", | |
"features": ["basic_document_analysis", "basic_risk_assessment", "video_analysis", "audio_analysis", "chatbot", "detailed_risk_assessment", "contract_clause_analysis"], | |
"limits": { | |
"document_size_mb": 50, | |
"documents_per_month": 999999, # Unlimited | |
"video_size_mb": 500, | |
"audio_size_mb": 200 | |
} | |
} | |
} | |
# Database connection management | |
def get_db_connection(): | |
"""Create and return a database connection with proper error handling""" | |
try: | |
# Ensure the directory exists | |
db_dir = os.path.dirname(DB_PATH) | |
os.makedirs(db_dir, exist_ok=True) | |
conn = sqlite3.connect(DB_PATH) | |
conn.row_factory = sqlite3.Row # Return rows as dictionaries | |
return conn | |
except sqlite3.Error as e: | |
logger.error(f"Database connection error: {e}") | |
raise Exception(f"Database connection failed: {e}") | |
# Database setup | |
# In the init_auth_db function, update the CREATE TABLE statement to match our schema | |
def init_auth_db(): | |
"""Initialize the authentication database with required tables""" | |
try: | |
conn = get_db_connection() | |
c = conn.cursor() | |
# Create users table with the correct schema | |
c.execute(''' | |
CREATE TABLE IF NOT EXISTS users ( | |
id TEXT PRIMARY KEY, | |
email TEXT UNIQUE NOT NULL, | |
hashed_password TEXT NOT NULL, | |
password TEXT, | |
subscription_tier TEXT DEFAULT 'free_tier', | |
is_active BOOLEAN DEFAULT 1, | |
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, | |
api_calls_remaining INTEGER DEFAULT 10, | |
last_reset_date TIMESTAMP DEFAULT CURRENT_TIMESTAMP | |
) | |
''') | |
# Create subscriptions table | |
c.execute(''' | |
CREATE TABLE IF NOT EXISTS subscriptions ( | |
id TEXT PRIMARY KEY, | |
user_id TEXT, | |
tier TEXT, | |
plan_id TEXT, | |
status TEXT, | |
created_at TIMESTAMP, | |
expires_at TIMESTAMP, | |
paypal_subscription_id TEXT, | |
FOREIGN KEY (user_id) REFERENCES users (id) | |
) | |
''') | |
# Create usage stats table | |
c.execute(''' | |
CREATE TABLE IF NOT EXISTS usage_stats ( | |
id TEXT PRIMARY KEY, | |
user_id TEXT, | |
month INTEGER, | |
year INTEGER, | |
analyses_used INTEGER, | |
FOREIGN KEY (user_id) REFERENCES users (id) | |
) | |
''') | |
# Create tokens table for refresh tokens | |
c.execute(''' | |
CREATE TABLE IF NOT EXISTS refresh_tokens ( | |
user_id TEXT, | |
token TEXT, | |
expires_at TIMESTAMP, | |
FOREIGN KEY (user_id) REFERENCES users (id) | |
) | |
''') | |
conn.commit() | |
logger.info("Database initialized successfully") | |
except Exception as e: | |
logger.error(f"Database initialization error: {e}") | |
raise | |
finally: | |
if conn: | |
conn.close() | |
# Initialize the database | |
init_auth_db() | |
# Password hashing with bcrypt | |
# Update the password hashing and verification functions to use a more reliable method | |
# Replace these functions | |
# Remove these conflicting functions | |
# def hash_password(password): | |
# """Hash a password using bcrypt""" | |
# return bcrypt.hash(password) | |
# | |
# def verify_password(plain_password, hashed_password): | |
# """Verify a password against its hash""" | |
# return bcrypt.verify(plain_password, hashed_password) | |
# Keep only these improved functions | |
def hash_password(password): | |
"""Hash a password using bcrypt""" | |
# Use a more direct approach to avoid bcrypt version issues | |
import bcrypt | |
# Convert password to bytes if it's not already | |
if isinstance(password, str): | |
password = password.encode('utf-8') | |
# Generate salt and hash | |
salt = bcrypt.gensalt() | |
hashed = bcrypt.hashpw(password, salt) | |
# Return as string for storage | |
return hashed.decode('utf-8') | |
def verify_password(plain_password, hashed_password): | |
"""Verify a password against its hash""" | |
import bcrypt | |
# Convert inputs to bytes if they're not already | |
if isinstance(plain_password, str): | |
plain_password = plain_password.encode('utf-8') | |
if isinstance(hashed_password, str): | |
hashed_password = hashed_password.encode('utf-8') | |
try: | |
# Use direct bcrypt verification | |
return bcrypt.checkpw(plain_password, hashed_password) | |
except Exception as e: | |
logger.error(f"Password verification error: {e}") | |
return False | |
# User registration | |
def register_user(email, password): | |
try: | |
conn = get_db_connection() | |
c = conn.cursor() | |
# Check if user already exists | |
c.execute("SELECT * FROM users WHERE email = ?", (email,)) | |
if c.fetchone(): | |
return False, "Email already registered" | |
# Create new user | |
user_id = str(uuid.uuid4()) | |
# Add more detailed logging | |
logger.info(f"Registering new user with email: {email}") | |
hashed_pw = hash_password(password) | |
logger.info(f"Password hashed successfully: {bool(hashed_pw)}") | |
c.execute(""" | |
INSERT INTO users | |
(id, email, hashed_password, subscription_tier, api_calls_remaining, last_reset_date) | |
VALUES (?, ?, ?, ?, ?, ?) | |
""", (user_id, email, hashed_pw, "free_tier", 5, datetime.now())) | |
conn.commit() | |
logger.info(f"User registered successfully: {email}") | |
# Verify the user was actually stored | |
c.execute("SELECT * FROM users WHERE email = ?", (email,)) | |
stored_user = c.fetchone() | |
logger.info(f"User verification after registration: {bool(stored_user)}") | |
access_token = create_access_token(user_id) | |
return True, { | |
"user_id": user_id, | |
"access_token": access_token, | |
"token_type": "bearer" | |
} | |
except Exception as e: | |
logger.error(f"User registration error: {e}") | |
return False, f"Registration failed: {str(e)}" | |
finally: | |
if conn: | |
conn.close() | |
# User login | |
# Fix the authenticate_user function | |
# In the authenticate_user function, update the password verification to use hashed_password | |
def authenticate_user(email, password): | |
"""Authenticate a user and return user data with tokens""" | |
try: | |
conn = get_db_connection() | |
c = conn.cursor() | |
# Get user by email | |
c.execute("SELECT * FROM users WHERE email = ? AND is_active = 1", (email,)) | |
user = c.fetchone() | |
if not user: | |
logger.warning(f"User not found: {email}") | |
return None | |
# Add debug logging for password verification | |
logger.info(f"Verifying password for user: {email}") | |
logger.info(f"Stored hashed password: {user['hashed_password'][:20]}...") | |
try: | |
# Check if password verification works | |
is_valid = verify_password(password, user['hashed_password']) | |
logger.info(f"Password verification result: {is_valid}") | |
if not is_valid: | |
logger.warning(f"Password verification failed for user: {email}") | |
return None | |
except Exception as e: | |
logger.error(f"Password verification error: {e}") | |
return None | |
# Update last login time if column exists | |
try: | |
c.execute("UPDATE users SET last_login = ? WHERE id = ?", | |
(datetime.now(), user['id'])) | |
conn.commit() | |
except sqlite3.OperationalError: | |
# last_login column might not exist | |
pass | |
# Convert sqlite3.Row to dict to use get() method | |
user_dict = dict(user) | |
# Create and return a User object | |
return User( | |
id=user_dict['id'], | |
email=user_dict['email'], | |
subscription_tier=user_dict.get('subscription_tier', 'free_tier'), | |
subscription_expiry=None, # Handle this properly if needed | |
api_calls_remaining=user_dict.get('api_calls_remaining', 5), | |
last_reset_date=user_dict.get('last_reset_date') | |
) | |
except Exception as e: | |
logger.error(f"Login error: {e}") | |
return None | |
finally: | |
if conn: | |
conn.close() | |
# Token generation and validation - completely replaced | |
def create_access_token(user_id): | |
"""Create a new access token for a user""" | |
try: | |
# Create a JWT token with user_id and expiration | |
expiration = datetime.now() + JWT_EXPIRATION_DELTA | |
# Create a token payload | |
payload = { | |
"sub": user_id, | |
"exp": expiration.timestamp() | |
} | |
# Generate the JWT token | |
token = jwt.encode(payload, SECRET_KEY, algorithm=ALGORITHM) | |
logger.info(f"Created access token for user: {user_id}") | |
return token | |
except Exception as e: | |
logger.error(f"Token creation error: {e}") | |
return None | |
def update_auth_db_schema(): | |
"""Update the authentication database schema with any missing columns""" | |
try: | |
conn = get_db_connection() | |
c = conn.cursor() | |
# Check if tier column exists in subscriptions table | |
c.execute("PRAGMA table_info(subscriptions)") | |
columns = [column[1] for column in c.fetchall()] | |
# Add tier column if it doesn't exist | |
if "tier" not in columns: | |
logger.info("Adding 'tier' column to subscriptions table") | |
c.execute("ALTER TABLE subscriptions ADD COLUMN tier TEXT") | |
conn.commit() | |
logger.info("Database schema updated successfully") | |
conn.close() | |
except Exception as e: | |
logger.error(f"Database schema update error: {e}") | |
raise HTTPException( | |
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, | |
detail=f"Database schema update error: {str(e)}" | |
) | |
# Add this to your get_current_user function | |
async def get_current_user(token: str = Depends(oauth2_scheme)): | |
credentials_exception = HTTPException( | |
status_code=status.HTTP_401_UNAUTHORIZED, | |
detail="Could not validate credentials", | |
headers={"WWW-Authenticate": "Bearer"}, | |
) | |
try: | |
# Decode the JWT token | |
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) | |
user_id: str = payload.get("sub") | |
if user_id is None: | |
logger.error("Token missing 'sub' field") | |
raise credentials_exception | |
except Exception as e: | |
logger.error(f"Token validation error: {str(e)}") | |
raise credentials_exception | |
# Get user from database | |
conn = get_db_connection() | |
cursor = conn.cursor() | |
cursor.execute("SELECT id, email, subscription_tier, is_active FROM users WHERE id = ?", (user_id,)) | |
user_data = cursor.fetchone() | |
conn.close() | |
if user_data is None: | |
logger.error(f"User not found: {user_id}") | |
raise credentials_exception | |
user = User( | |
id=user_data[0], | |
email=user_data[1], | |
subscription_tier=user_data[2], | |
is_active=bool(user_data[3]) | |
) | |
return user | |
async def get_current_active_user(current_user: User = Depends(get_current_user)): | |
"""Get the current active user""" | |
return current_user | |
def create_user_subscription(email, tier): | |
"""Create a subscription for a user""" | |
try: | |
# Get user by email | |
conn = get_db_connection() | |
c = conn.cursor() | |
# Get user ID | |
c.execute("SELECT id FROM users WHERE email = ?", (email,)) | |
user_data = c.fetchone() | |
if not user_data: | |
return False, "User not found" | |
user_id = user_data['id'] | |
# Check if tier is valid | |
valid_tiers = ["standard_tier", "premium_tier"] | |
if tier not in valid_tiers: | |
return False, f"Invalid tier: {tier}. Must be one of {valid_tiers}" | |
# Create subscription | |
subscription_id = str(uuid.uuid4()) | |
created_at = datetime.now() | |
expires_at = created_at + timedelta(days=30) # 30-day subscription | |
# Insert subscription | |
c.execute(""" | |
INSERT INTO subscriptions | |
(id, user_id, tier, status, created_at, expires_at) | |
VALUES (?, ?, ?, ?, ?, ?) | |
""", (subscription_id, user_id, tier, "active", created_at, expires_at)) | |
# Update user's subscription tier | |
c.execute(""" | |
UPDATE users | |
SET subscription_tier = ? | |
WHERE id = ? | |
""", (tier, user_id)) | |
conn.commit() | |
return True, { | |
"id": subscription_id, | |
"user_id": user_id, | |
"tier": tier, | |
"status": "active", | |
"created_at": created_at.isoformat(), | |
"expires_at": expires_at.isoformat() | |
} | |
except Exception as e: | |
logger.error(f"Subscription creation error: {e}") | |
return False, f"Failed to create subscription: {str(e)}" | |
finally: | |
if conn: | |
conn.close() | |
def get_user(user_id: str): | |
"""Get user by ID""" | |
try: | |
conn = get_db_connection() | |
c = conn.cursor() | |
# Get user | |
c.execute("SELECT * FROM users WHERE id = ? AND is_active = 1", (user_id,)) | |
user_data = c.fetchone() | |
if not user_data: | |
return None | |
# Convert to User model | |
user_dict = dict(user_data) | |
# Handle datetime conversions if needed | |
if user_dict.get("subscription_expiry") and isinstance(user_dict["subscription_expiry"], str): | |
user_dict["subscription_expiry"] = datetime.fromisoformat(user_dict["subscription_expiry"]) | |
if user_dict.get("last_reset_date") and isinstance(user_dict["last_reset_date"], str): | |
user_dict["last_reset_date"] = datetime.fromisoformat(user_dict["last_reset_date"]) | |
return User( | |
id=user_dict['id'], | |
email=user_dict['email'], | |
subscription_tier=user_dict['subscription_tier'], | |
subscription_expiry=user_dict.get('subscription_expiry'), | |
api_calls_remaining=user_dict.get('api_calls_remaining', 5), | |
last_reset_date=user_dict.get('last_reset_date') | |
) | |
except Exception as e: | |
logger.error(f"Get user error: {e}") | |
return None | |
finally: | |
if conn: | |
conn.close() | |
def check_subscription_access(user: User, feature: str, file_size_mb: Optional[float] = None): | |
"""Check if the user has access to the requested feature and file size""" | |
# Check if subscription is expired | |
if user.subscription_tier != "free_tier" and user.subscription_expiry and user.subscription_expiry < datetime.now(): | |
# Downgrade to free tier if subscription expired | |
user.subscription_tier = "free_tier" | |
user.api_calls_remaining = SUBSCRIPTION_TIERS["free_tier"]["daily_api_calls"] | |
with get_db_connection() as conn: | |
c = conn.cursor() | |
c.execute(""" | |
UPDATE users | |
SET subscription_tier = ?, api_calls_remaining = ? | |
WHERE id = ? | |
""", (user.subscription_tier, user.api_calls_remaining, user.id)) | |
conn.commit() | |
# Reset API calls if needed | |
user = reset_api_calls_if_needed(user) | |
# Check if user has API calls remaining | |
if user.api_calls_remaining <= 0: | |
raise HTTPException( | |
status_code=429, | |
detail="API call limit reached for today. Please upgrade your subscription or try again tomorrow." | |
) | |
# Check if feature is available in user's subscription tier | |
tier_features = SUBSCRIPTION_TIERS[user.subscription_tier]["features"] | |
if feature not in tier_features: | |
raise HTTPException( | |
status_code=403, | |
detail=f"The {feature} feature is not available in your {user.subscription_tier} subscription. Please upgrade to access this feature." | |
) | |
# Check file size limit if applicable | |
if file_size_mb: | |
max_size = SUBSCRIPTION_TIERS[user.subscription_tier]["max_document_size_mb"] | |
if file_size_mb > max_size: | |
raise HTTPException( | |
status_code=413, | |
detail=f"File size exceeds the {max_size}MB limit for your {user.subscription_tier} subscription. Please upgrade or use a smaller file." | |
) | |
# Decrement API calls remaining | |
user.api_calls_remaining -= 1 | |
with get_db_connection() as conn: | |
c = conn.cursor() | |
c.execute(""" | |
UPDATE users | |
SET api_calls_remaining = ? | |
WHERE id = ? | |
""", (user.api_calls_remaining, user.id)) | |
conn.commit() | |
return True | |
def reset_api_calls_if_needed(user: User): | |
"""Reset API call counter if it's a new day""" | |
today = datetime.now().date() | |
if user.last_reset_date is None or user.last_reset_date.date() < today: | |
tier_limits = SUBSCRIPTION_TIERS[user.subscription_tier] | |
user.api_calls_remaining = tier_limits["daily_api_calls"] | |
user.last_reset_date = datetime.now() | |
# Update the user in the database | |
with get_db_connection() as conn: | |
c = conn.cursor() | |
c.execute(""" | |
UPDATE users | |
SET api_calls_remaining = ?, last_reset_date = ? | |
WHERE id = ? | |
""", (user.api_calls_remaining, user.last_reset_date, user.id)) | |
conn.commit() | |
return user | |
def login_user(email, password): | |
"""Login a user with email and password""" | |
try: | |
# Authenticate user | |
user = authenticate_user(email, password) | |
if not user: | |
return False, "Incorrect username or password" | |
# Create access token | |
access_token = create_access_token(user.id) | |
# Create refresh token | |
refresh_token = str(uuid.uuid4()) | |
expires_at = datetime.now() + timedelta(days=30) | |
# Store refresh token | |
conn = get_db_connection() | |
c = conn.cursor() | |
c.execute("INSERT INTO refresh_tokens VALUES (?, ?, ?)", | |
(user.id, refresh_token, expires_at)) | |
conn.commit() | |
# Get subscription info | |
c.execute("SELECT * FROM subscriptions WHERE user_id = ? AND status = 'active'", (user.id,)) | |
subscription = c.fetchone() | |
# Convert subscription to dict if it exists, otherwise set to None | |
subscription_dict = dict(subscription) if subscription else None | |
conn.close() | |
return True, { | |
"user_id": user.id, | |
"email": user.email, | |
"access_token": access_token, | |
"refresh_token": refresh_token, | |
"subscription": subscription_dict | |
} | |
except Exception as e: | |
logger.error(f"Login error: {e}") | |
return False, f"Login failed: {str(e)}" |