import hashlib
import os
import secrets
from datetime import datetime, timedelta

from flask import Flask, request, jsonify
from flask_jwt_extended import JWTManager, create_access_token, get_jwt_identity, get_jwt


def setup_secure_cookies(app: Flask):
    """Register an after-request hook that mirrors auth state into cookies.

    For successful (status < 400) POST/PUT responses the hook:

    * stores the request's ``Authorization: Bearer <token>`` value in an
      HttpOnly ``access_token`` cookie, and
    * sets a random HttpOnly ``refresh_token`` cookie (7 days) when the JSON
      request body contains a truthy ``rememberMe``.

    Cookies are marked ``Secure`` unless the app config indicates development
    (``DEBUG``, ``ENV == 'development'`` or ``ENVIRONMENT == 'development'``).
    They use ``SameSite=Lax`` when running on Hugging Face Spaces (detected via
    the ``SPACE_ID`` environment variable) and ``SameSite=Strict`` otherwise.

    Returns:
        The same *app*, for chaining.
    """
    is_development = (
        app.config.get('DEBUG', False)
        or app.config.get('ENV') == 'development'
        or app.config.get('ENVIRONMENT') == 'development'
    )
    # Hugging Face Spaces sets SPACE_ID in the environment.
    is_huggingface = os.environ.get('SPACE_ID') is not None

    # These depend only on configuration, so compute them once at setup time.
    secure_cookie = not is_development
    samesite_policy = 'Lax' if is_huggingface else 'Strict'

    def _bearer_token():
        """Return the token from an ``Authorization: Bearer`` header, or None."""
        header = request.headers.get('Authorization', '')
        prefix = 'Bearer '
        if header.startswith(prefix):
            return header[len(prefix):] or None
        return None

    def _remember_me_requested():
        """Return the ``rememberMe`` value from a JSON object body, else False."""
        if not request.is_json:
            return False
        try:
            json_data = request.get_json(silent=True)
        except Exception:
            # Best effort: silent=True already turns parse errors into None;
            # guard against anything else so this hook never breaks a response.
            return False
        if isinstance(json_data, dict):
            return json_data.get('rememberMe', False)
        return False

    @app.after_request
    def set_secure_cookies(response):
        """Mirror auth state into cookies on successful POST/PUT responses."""
        if request.method not in ('POST', 'PUT'):
            return response
        # Never (re)write auth cookies on error responses: the JWT error
        # callbacks clear these cookies on 401, and re-setting them here
        # would silently undo that.
        if response.status_code >= 400:
            return response

        cookie_options = {
            'httponly': True,  # Not readable from JavaScript (limits XSS impact)
            'secure': secure_cookie,  # HTTPS-only outside development
            'samesite': samesite_policy,
            'path': '/',  # Available across all paths
        }

        token = _bearer_token()
        # Without a Bearer token there is nothing to store; passing None to
        # set_cookie raises and would turn the response into a 500.
        if token is not None:
            response.set_cookie(
                'access_token',
                token,
                # NOTE: flask_jwt_extended's default access-token lifetime is
                # 15 minutes; keep this in sync with JWT_ACCESS_TOKEN_EXPIRES.
                max_age=3600,
                **cookie_options,
            )

        if _remember_me_requested():
            # NOTE(review): this random value is not stored or verified
            # anywhere in this module — confirm something consumes it.
            response.set_cookie(
                'refresh_token',
                secrets.token_urlsafe(32),
                max_age=7 * 24 * 60 * 60,  # 7 days
                **cookie_options,
            )
        return response

    return app


def configure_jwt_with_cookies(app: Flask):
    """Create a ``JWTManager`` for *app* and register JSON error callbacks.

    Expired or invalid tokens produce a 401 JSON response that also clears
    the ``access_token`` and ``refresh_token`` cookies. A missing token
    produces a plain 401 JSON response. When the request's ``Origin`` is one
    of the allowed origins, error responses carry credentialed CORS headers
    for that origin.

    Returns:
        The configured ``JWTManager`` instance.
    """
    jwt = JWTManager(app)

    # Origins allowed to read these error responses with credentials.
    # NOTE(review): presumably mirrors the app's CORS configuration — keep in sync.
    allowed_origins = frozenset({
        'http://localhost:3000',
        'http://localhost:5000',
        'http://127.0.0.1:3000',
        'http://127.0.0.1:5000',
        'http://192.168.1.4:3000',
        'https://zelyanoth-lin-cbfcff2.hf.space',
    })

    def _with_cors(response):
        """Add credentialed CORS headers if the request Origin is allowed."""
        # Access-Control-Allow-Origin accepts exactly one origin (and '*' is
        # not permitted together with credentials), so echo the caller's
        # origin instead of emitting one header per allowed origin.
        origin = request.headers.get('Origin')
        if origin in allowed_origins:
            response.headers['Access-Control-Allow-Origin'] = origin
            response.headers['Access-Control-Allow-Credentials'] = 'true'
            # The response now differs per Origin; tell caches about it.
            response.vary.add('Origin')
        return response

    def _auth_error(message, *, clear_cookies):
        """Build a 401 JSON error response, optionally clearing auth cookies."""
        response = jsonify({'success': False, 'message': message})
        if clear_cookies:
            response.set_cookie('access_token', '', expires=0, path='/')
            response.set_cookie('refresh_token', '', expires=0, path='/')
        return _with_cors(response), 401

    @jwt.token_verification_loader
    def verify_token_on_refresh_callback(jwt_header, jwt_payload):
        """Accept every token that passed the library's own validation.

        NOTE(review): returns True unconditionally, so no extra checks (such
        as a refresh-token or revocation lookup) are performed here.
        """
        return True

    @jwt.expired_token_loader
    def expired_token_callback(jwt_header, jwt_payload):
        """Reject an expired token and clear the auth cookies."""
        return _auth_error('Token has expired', clear_cookies=True)

    @jwt.invalid_token_loader
    def invalid_token_callback(error):
        """Reject a malformed or otherwise invalid token and clear the auth cookies."""
        return _auth_error('Invalid token', clear_cookies=True)

    @jwt.unauthorized_loader
    def missing_token_callback(error):
        """Reject a request that carries no JWT.

        NOTE: ``request.headers`` is immutable in Flask, and this callback's
        return value becomes the response, so a token found only in the
        ``access_token`` cookie cannot be injected into the request here. To
        accept cookie tokens, configure ``JWT_TOKEN_LOCATION`` to include
        ``"cookies"`` and set ``JWT_ACCESS_COOKIE_NAME = "access_token"``
        (mind ``JWT_COOKIE_CSRF_PROTECT``) instead.
        """
        return _auth_error('Missing token', clear_cookies=False)

    return jwt