import hashlib
import secrets
from datetime import datetime, timedelta

from flask import Flask, jsonify, request
from flask_jwt_extended import JWTManager, create_access_token, get_jwt, get_jwt_identity

# Cookie names shared by the after_request hook and the JWT error handlers.
_ACCESS_COOKIE = 'access_token'
_REFRESH_COOKIE = 'refresh_token'
_BEARER_PREFIX = 'Bearer '
# Origin echoed in the CORS headers of the JWT error responses.
_DEFAULT_CORS_ORIGIN = 'http://localhost:3000'


def _wants_remember_me() -> bool:
    """Return True if the current request's JSON body has a truthy ``rememberMe``."""
    if not request.is_json:
        return False
    try:
        payload = request.get_json(silent=True)
    except Exception:
        # silent=True already turns malformed JSON into None.
        # This guard covers other failures while reading the body, so the
        # cookie hook can never turn a finished response into a 500.
        return False
    if not isinstance(payload, dict):
        return False
    return bool(payload.get('rememberMe', False))


def setup_secure_cookies(app: Flask) -> Flask:
    """Register an ``after_request`` hook that mirrors auth state into cookies.

    For successful (status < 400) POST/PUT responses the hook:

    * copies a ``Bearer`` token from the request's ``Authorization`` header
      into an HttpOnly ``access_token`` cookie, and
    * if the JSON request body has a truthy ``rememberMe`` field, sets an
      HttpOnly ``refresh_token`` cookie holding a fresh random value.

    Cookies are ``SameSite=Strict`` and marked ``Secure`` unless the app is in
    debug/development mode.

    Args:
        app: The Flask application to configure.

    Returns:
        The same ``app``, for chaining.
    """
    # Evaluated once at setup time; later changes to DEBUG/ENV are not seen.
    is_development = app.config.get('DEBUG', False) or app.config.get('ENV') == 'development'
    secure = not is_development

    @app.after_request
    def set_secure_cookies(response):
        """Attach auth cookies to successful POST/PUT responses."""
        # Only requests that typically carry JSON bodies are considered.
        if request.method not in ('POST', 'PUT'):
            return response
        # Never persist a token the server just rejected, and never mint a
        # refresh cookie on a failed request. The JWT error handlers clear
        # these cookies on 401, and re-setting them here would undo that,
        # because the later Set-Cookie header wins.
        if response.status_code >= 400:
            return response

        auth_header = request.headers.get('Authorization')
        if auth_header and auth_header.startswith(_BEARER_PREFIX):
            token = auth_header[len(_BEARER_PREFIX):]
            response.set_cookie(
                _ACCESS_COOKIE,
                token,
                httponly=True,  # not readable from JavaScript (limits XSS token theft)
                secure=secure,  # HTTPS-only outside development
                samesite='Strict',  # not sent on cross-site requests (CSRF mitigation)
                # 1 hour. NOTE: flask_jwt_extended's default access-token
                # lifetime is 15 minutes, so this cookie may outlive the token.
                max_age=3600,
            )

        if _wants_remember_me():
            response.set_cookie(
                _REFRESH_COOKIE,
                # NOTE(review): this random value is not stored or tied to a
                # user anywhere in this module. Confirm how it is consumed.
                secrets.token_urlsafe(32),
                httponly=True,
                secure=secure,
                samesite='Strict',
                max_age=7 * 24 * 60 * 60,  # 7 days
            )
        return response

    return app


def configure_jwt_with_cookies(app: Flask, *, cors_origin: str = _DEFAULT_CORS_ORIGIN) -> JWTManager:
    """Create a ``JWTManager`` for ``app`` that also accepts the ``access_token`` cookie.

    Unless already configured, the JWT is looked up in the ``Authorization``
    header first and then in the ``access_token`` cookie set by
    ``setup_secure_cookies``. Expired, invalid and missing tokens produce a
    JSON ``401`` response carrying CORS headers for ``cors_origin``. Expired
    and invalid tokens also clear both auth cookies.

    NOTE: with cookies enabled, flask_jwt_extended's JWT_COOKIE_CSRF_PROTECT
    (on by default) expects an X-CSRF-TOKEN header on cookie-authenticated
    state-changing requests.

    Args:
        app: The Flask application to configure.
        cors_origin: Value for ``Access-Control-Allow-Origin`` on auth error responses.

    Returns:
        The configured ``JWTManager``.
    """
    # Set before constructing JWTManager so these act as defaults; values
    # already present in app.config are kept.
    app.config.setdefault('JWT_TOKEN_LOCATION', ['headers', 'cookies'])
    app.config.setdefault('JWT_ACCESS_COOKIE_NAME', _ACCESS_COOKIE)
    jwt = JWTManager(app)

    def _auth_error(message: str, *, clear_cookies: bool):
        """Build the JSON 401 response shared by the error loaders below."""
        response = jsonify({'success': False, 'message': message})
        if clear_cookies:
            response.set_cookie(_ACCESS_COOKIE, '', expires=0)
            response.set_cookie(_REFRESH_COOKIE, '', expires=0)
        # Allow the configured frontend origin to read the credentialed error.
        response.headers['Access-Control-Allow-Origin'] = cors_origin
        response.headers['Access-Control-Allow-Credentials'] = 'true'
        return response, 401

    @jwt.token_verification_loader
    def verify_token_on_refresh_callback(jwt_header, jwt_payload):
        """Accept every decoded token (placeholder for real refresh checks)."""
        # This is a simplified version; in production you'd check a refresh token.
        return True

    @jwt.expired_token_loader
    def expired_token_callback(jwt_header, jwt_payload):
        """Handle expired tokens by clearing the auth cookies."""
        return _auth_error('Token has expired', clear_cookies=True)

    @jwt.invalid_token_loader
    def invalid_token_callback(error):
        """Handle invalid tokens by clearing the auth cookies."""
        return _auth_error('Invalid token', clear_cookies=True)

    @jwt.unauthorized_loader
    def missing_token_callback(error):
        """Handle requests with no token in any configured location."""
        # Cookie tokens are found by flask_jwt_extended itself (see
        # JWT_TOKEN_LOCATION above). request.headers is read-only, so a token
        # cannot be injected from here, and this loader must return a response.
        return _auth_error('Missing token', clear_cookies=False)

    return jwt