|  | import json | 
					
						
						|  | import logging | 
					
						
						|  | import random | 
					
						
						|  | import re | 
					
						
						|  | import string | 
					
						
						|  | import subprocess | 
					
						
						|  | import time | 
					
						
						|  | import uuid | 
					
						
						|  | from collections.abc import Generator | 
					
						
						|  | from datetime import datetime | 
					
						
						|  | from hashlib import sha256 | 
					
						
						|  | from typing import Any, Optional, Union | 
					
						
						|  | from zoneinfo import available_timezones | 
					
						
						|  |  | 
					
						
						|  | from flask import Response, stream_with_context | 
					
						
						|  | from flask_restful import fields | 
					
						
						|  |  | 
					
						
						|  | from configs import dify_config | 
					
						
						|  | from core.app.features.rate_limiting.rate_limit import RateLimitGenerator | 
					
						
						|  | from core.file import helpers as file_helpers | 
					
						
						|  | from extensions.ext_redis import redis_client | 
					
						
						|  | from models.account import Account | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def run(script): | 
					
						
						|  | return subprocess.getstatusoutput("source /root/.bashrc && " + script) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class AppIconUrlField(fields.Raw): | 
					
						
						|  | def output(self, key, obj): | 
					
						
						|  | if obj is None: | 
					
						
						|  | return None | 
					
						
						|  |  | 
					
						
						|  | from models.model import IconType | 
					
						
						|  |  | 
					
						
						|  | if obj.icon_type == IconType.IMAGE.value: | 
					
						
						|  | return file_helpers.get_signed_file_url(obj.icon) | 
					
						
						|  | return None | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class TimestampField(fields.Raw): | 
					
						
						|  | def format(self, value) -> int: | 
					
						
						|  | return int(value.timestamp()) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def email(email): | 
					
						
						|  |  | 
					
						
						|  | pattern = r"^[\w\.!#$%&'*+\-/=?^_`{|}~]+@([\w-]+\.)+[\w-]{2,}$" | 
					
						
						|  |  | 
					
						
						|  | if re.match(pattern, email) is not None: | 
					
						
						|  | return email | 
					
						
						|  |  | 
					
						
						|  | error = "{email} is not a valid email.".format(email=email) | 
					
						
						|  | raise ValueError(error) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def uuid_value(value): | 
					
						
						|  | if value == "": | 
					
						
						|  | return str(value) | 
					
						
						|  |  | 
					
						
						|  | try: | 
					
						
						|  | uuid_obj = uuid.UUID(value) | 
					
						
						|  | return str(uuid_obj) | 
					
						
						|  | except ValueError: | 
					
						
						|  | error = "{value} is not a valid uuid.".format(value=value) | 
					
						
						|  | raise ValueError(error) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def alphanumeric(value: str): | 
					
						
						|  |  | 
					
						
						|  | if re.match(r"^[a-zA-Z0-9_]+$", value): | 
					
						
						|  | return value | 
					
						
						|  |  | 
					
						
						|  | raise ValueError(f"{value} is not a valid alphanumeric value") | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def timestamp_value(timestamp): | 
					
						
						|  | try: | 
					
						
						|  | int_timestamp = int(timestamp) | 
					
						
						|  | if int_timestamp < 0: | 
					
						
						|  | raise ValueError | 
					
						
						|  | return int_timestamp | 
					
						
						|  | except ValueError: | 
					
						
						|  | error = "{timestamp} is not a valid timestamp.".format(timestamp=timestamp) | 
					
						
						|  | raise ValueError(error) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class StrLen: | 
					
						
						|  | """Restrict input to an integer in a range (inclusive)""" | 
					
						
						|  |  | 
					
						
						|  | def __init__(self, max_length, argument="argument"): | 
					
						
						|  | self.max_length = max_length | 
					
						
						|  | self.argument = argument | 
					
						
						|  |  | 
					
						
						|  | def __call__(self, value): | 
					
						
						|  | length = len(value) | 
					
						
						|  | if length > self.max_length: | 
					
						
						|  | error = "Invalid {arg}: {val}. {arg} cannot exceed length {length}".format( | 
					
						
						|  | arg=self.argument, val=value, length=self.max_length | 
					
						
						|  | ) | 
					
						
						|  | raise ValueError(error) | 
					
						
						|  |  | 
					
						
						|  | return value | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class FloatRange: | 
					
						
						|  | """Restrict input to an float in a range (inclusive)""" | 
					
						
						|  |  | 
					
						
						|  | def __init__(self, low, high, argument="argument"): | 
					
						
						|  | self.low = low | 
					
						
						|  | self.high = high | 
					
						
						|  | self.argument = argument | 
					
						
						|  |  | 
					
						
						|  | def __call__(self, value): | 
					
						
						|  | value = _get_float(value) | 
					
						
						|  | if value < self.low or value > self.high: | 
					
						
						|  | error = "Invalid {arg}: {val}. {arg} must be within the range {lo} - {hi}".format( | 
					
						
						|  | arg=self.argument, val=value, lo=self.low, hi=self.high | 
					
						
						|  | ) | 
					
						
						|  | raise ValueError(error) | 
					
						
						|  |  | 
					
						
						|  | return value | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class DatetimeString: | 
					
						
						|  | def __init__(self, format, argument="argument"): | 
					
						
						|  | self.format = format | 
					
						
						|  | self.argument = argument | 
					
						
						|  |  | 
					
						
						|  | def __call__(self, value): | 
					
						
						|  | try: | 
					
						
						|  | datetime.strptime(value, self.format) | 
					
						
						|  | except ValueError: | 
					
						
						|  | error = "Invalid {arg}: {val}. {arg} must be conform to the format {format}".format( | 
					
						
						|  | arg=self.argument, val=value, format=self.format | 
					
						
						|  | ) | 
					
						
						|  | raise ValueError(error) | 
					
						
						|  |  | 
					
						
						|  | return value | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def _get_float(value): | 
					
						
						|  | try: | 
					
						
						|  | return float(value) | 
					
						
						|  | except (TypeError, ValueError): | 
					
						
						|  | raise ValueError("{} is not a valid float".format(value)) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def timezone(timezone_string): | 
					
						
						|  | if timezone_string and timezone_string in available_timezones(): | 
					
						
						|  | return timezone_string | 
					
						
						|  |  | 
					
						
						|  | error = "{timezone_string} is not a valid timezone.".format(timezone_string=timezone_string) | 
					
						
						|  | raise ValueError(error) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def generate_string(n): | 
					
						
						|  | letters_digits = string.ascii_letters + string.digits | 
					
						
						|  | result = "" | 
					
						
						|  | for i in range(n): | 
					
						
						|  | result += random.choice(letters_digits) | 
					
						
						|  |  | 
					
						
						|  | return result | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def extract_remote_ip(request) -> str: | 
					
						
						|  | if request.headers.get("CF-Connecting-IP"): | 
					
						
						|  | return request.headers.get("Cf-Connecting-Ip") | 
					
						
						|  | elif request.headers.getlist("X-Forwarded-For"): | 
					
						
						|  | return request.headers.getlist("X-Forwarded-For")[0] | 
					
						
						|  | else: | 
					
						
						|  | return request.remote_addr | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def generate_text_hash(text: str) -> str: | 
					
						
						|  | hash_text = str(text) + "None" | 
					
						
						|  | return sha256(hash_text.encode()).hexdigest() | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def compact_generate_response(response: Union[dict, RateLimitGenerator]) -> Response: | 
					
						
						|  | if isinstance(response, dict): | 
					
						
						|  | return Response(response=json.dumps(response), status=200, mimetype="application/json") | 
					
						
						|  | else: | 
					
						
						|  |  | 
					
						
						|  | def generate() -> Generator: | 
					
						
						|  | yield from response | 
					
						
						|  |  | 
					
						
						|  | return Response(stream_with_context(generate()), status=200, mimetype="text/event-stream") | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class TokenManager: | 
					
						
						|  | @classmethod | 
					
						
						|  | def generate_token( | 
					
						
						|  | cls, | 
					
						
						|  | token_type: str, | 
					
						
						|  | account: Optional[Account] = None, | 
					
						
						|  | email: Optional[str] = None, | 
					
						
						|  | additional_data: Optional[dict] = None, | 
					
						
						|  | ) -> str: | 
					
						
						|  | if account is None and email is None: | 
					
						
						|  | raise ValueError("Account or email must be provided") | 
					
						
						|  |  | 
					
						
						|  | account_id = account.id if account else None | 
					
						
						|  | account_email = account.email if account else email | 
					
						
						|  |  | 
					
						
						|  | if account_id: | 
					
						
						|  | old_token = cls._get_current_token_for_account(account_id, token_type) | 
					
						
						|  | if old_token: | 
					
						
						|  | if isinstance(old_token, bytes): | 
					
						
						|  | old_token = old_token.decode("utf-8") | 
					
						
						|  | cls.revoke_token(old_token, token_type) | 
					
						
						|  |  | 
					
						
						|  | token = str(uuid.uuid4()) | 
					
						
						|  | token_data = {"account_id": account_id, "email": account_email, "token_type": token_type} | 
					
						
						|  | if additional_data: | 
					
						
						|  | token_data.update(additional_data) | 
					
						
						|  |  | 
					
						
						|  | expiry_minutes = dify_config.model_dump().get(f"{token_type.upper()}_TOKEN_EXPIRY_MINUTES") | 
					
						
						|  | token_key = cls._get_token_key(token, token_type) | 
					
						
						|  | expiry_time = int(expiry_minutes * 60) | 
					
						
						|  | redis_client.setex(token_key, expiry_time, json.dumps(token_data)) | 
					
						
						|  |  | 
					
						
						|  | if account_id: | 
					
						
						|  | cls._set_current_token_for_account(account.id, token, token_type, expiry_minutes) | 
					
						
						|  |  | 
					
						
						|  | return token | 
					
						
						|  |  | 
					
						
						|  | @classmethod | 
					
						
						|  | def _get_token_key(cls, token: str, token_type: str) -> str: | 
					
						
						|  | return f"{token_type}:token:{token}" | 
					
						
						|  |  | 
					
						
						|  | @classmethod | 
					
						
						|  | def revoke_token(cls, token: str, token_type: str): | 
					
						
						|  | token_key = cls._get_token_key(token, token_type) | 
					
						
						|  | redis_client.delete(token_key) | 
					
						
						|  |  | 
					
						
						|  | @classmethod | 
					
						
						|  | def get_token_data(cls, token: str, token_type: str) -> Optional[dict[str, Any]]: | 
					
						
						|  | key = cls._get_token_key(token, token_type) | 
					
						
						|  | token_data_json = redis_client.get(key) | 
					
						
						|  | if token_data_json is None: | 
					
						
						|  | logging.warning(f"{token_type} token {token} not found with key {key}") | 
					
						
						|  | return None | 
					
						
						|  | token_data = json.loads(token_data_json) | 
					
						
						|  | return token_data | 
					
						
						|  |  | 
					
						
						|  | @classmethod | 
					
						
						|  | def _get_current_token_for_account(cls, account_id: str, token_type: str) -> Optional[str]: | 
					
						
						|  | key = cls._get_account_token_key(account_id, token_type) | 
					
						
						|  | current_token = redis_client.get(key) | 
					
						
						|  | return current_token | 
					
						
						|  |  | 
					
						
						|  | @classmethod | 
					
						
						|  | def _set_current_token_for_account( | 
					
						
						|  | cls, account_id: str, token: str, token_type: str, expiry_hours: Union[int, float] | 
					
						
						|  | ): | 
					
						
						|  | key = cls._get_account_token_key(account_id, token_type) | 
					
						
						|  | expiry_time = int(expiry_hours * 60 * 60) | 
					
						
						|  | redis_client.setex(key, expiry_time, token) | 
					
						
						|  |  | 
					
						
						|  | @classmethod | 
					
						
						|  | def _get_account_token_key(cls, account_id: str, token_type: str) -> str: | 
					
						
						|  | return f"{token_type}:account:{account_id}" | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class RateLimiter: | 
					
						
						|  | def __init__(self, prefix: str, max_attempts: int, time_window: int): | 
					
						
						|  | self.prefix = prefix | 
					
						
						|  | self.max_attempts = max_attempts | 
					
						
						|  | self.time_window = time_window | 
					
						
						|  |  | 
					
						
						|  | def _get_key(self, email: str) -> str: | 
					
						
						|  | return f"{self.prefix}:{email}" | 
					
						
						|  |  | 
					
						
						|  | def is_rate_limited(self, email: str) -> bool: | 
					
						
						|  | key = self._get_key(email) | 
					
						
						|  | current_time = int(time.time()) | 
					
						
						|  | window_start_time = current_time - self.time_window | 
					
						
						|  |  | 
					
						
						|  | redis_client.zremrangebyscore(key, "-inf", window_start_time) | 
					
						
						|  | attempts = redis_client.zcard(key) | 
					
						
						|  |  | 
					
						
						|  | if attempts and int(attempts) >= self.max_attempts: | 
					
						
						|  | return True | 
					
						
						|  | return False | 
					
						
						|  |  | 
					
						
						|  | def increment_rate_limit(self, email: str): | 
					
						
						|  | key = self._get_key(email) | 
					
						
						|  | current_time = int(time.time()) | 
					
						
						|  |  | 
					
						
						|  | redis_client.zadd(key, {current_time: current_time}) | 
					
						
						|  | redis_client.expire(key, self.time_window * 2) | 
					
						
						|  |  |