""" Qdrant Semantic Cache implementation Has 4 methods: - set_cache - get_cache - async_set_cache - async_get_cache """ import ast import asyncio import json from typing import Any import litellm from litellm._logging import print_verbose from .base_cache import BaseCache class QdrantSemanticCache(BaseCache): def __init__( # noqa: PLR0915 self, qdrant_api_base=None, qdrant_api_key=None, collection_name=None, similarity_threshold=None, quantization_config=None, embedding_model="text-embedding-ada-002", host_type=None, ): import os from litellm.llms.custom_httpx.http_handler import ( _get_httpx_client, get_async_httpx_client, httpxSpecialProvider, ) from litellm.secret_managers.main import get_secret_str if collection_name is None: raise Exception("collection_name must be provided, passed None") self.collection_name = collection_name print_verbose( f"qdrant semantic-cache initializing COLLECTION - {self.collection_name}" ) if similarity_threshold is None: raise Exception("similarity_threshold must be provided, passed None") self.similarity_threshold = similarity_threshold self.embedding_model = embedding_model headers = {} # check if defined as os.environ/ variable if qdrant_api_base: if isinstance(qdrant_api_base, str) and qdrant_api_base.startswith( "os.environ/" ): qdrant_api_base = get_secret_str(qdrant_api_base) if qdrant_api_key: if isinstance(qdrant_api_key, str) and qdrant_api_key.startswith( "os.environ/" ): qdrant_api_key = get_secret_str(qdrant_api_key) qdrant_api_base = ( qdrant_api_base or os.getenv("QDRANT_URL") or os.getenv("QDRANT_API_BASE") ) qdrant_api_key = qdrant_api_key or os.getenv("QDRANT_API_KEY") headers = {"Content-Type": "application/json"} if qdrant_api_key: headers["api-key"] = qdrant_api_key if qdrant_api_base is None: raise ValueError("Qdrant url must be provided") self.qdrant_api_base = qdrant_api_base self.qdrant_api_key = qdrant_api_key print_verbose(f"qdrant semantic-cache qdrant_api_base: {self.qdrant_api_base}") self.headers = headers self.sync_client = _get_httpx_client() self.async_client = get_async_httpx_client( llm_provider=httpxSpecialProvider.Caching ) if quantization_config is None: print_verbose( "Quantization config is not provided. Default binary quantization will be used." ) collection_exists = self.sync_client.get( url=f"{self.qdrant_api_base}/collections/{self.collection_name}/exists", headers=self.headers, ) if collection_exists.status_code != 200: raise ValueError( f"Error from qdrant checking if /collections exist {collection_exists.text}" ) if collection_exists.json()["result"]["exists"]: collection_details = self.sync_client.get( url=f"{self.qdrant_api_base}/collections/{self.collection_name}", headers=self.headers, ) self.collection_info = collection_details.json() print_verbose( f"Collection already exists.\nCollection details:{self.collection_info}" ) else: if quantization_config is None or quantization_config == "binary": quantization_params = { "binary": { "always_ram": False, } } elif quantization_config == "scalar": quantization_params = { "scalar": {"type": "int8", "quantile": 0.99, "always_ram": False} } elif quantization_config == "product": quantization_params = { "product": {"compression": "x16", "always_ram": False} } else: raise Exception( "Quantization config must be one of 'scalar', 'binary' or 'product'" ) new_collection_status = self.sync_client.put( url=f"{self.qdrant_api_base}/collections/{self.collection_name}", json={ "vectors": {"size": 1536, "distance": "Cosine"}, "quantization_config": quantization_params, }, headers=self.headers, ) if new_collection_status.json()["result"]: collection_details = self.sync_client.get( url=f"{self.qdrant_api_base}/collections/{self.collection_name}", headers=self.headers, ) self.collection_info = collection_details.json() print_verbose( f"New collection created.\nCollection details:{self.collection_info}" ) else: raise Exception("Error while creating new collection") def _get_cache_logic(self, cached_response: Any): if cached_response is None: return cached_response try: cached_response = json.loads( cached_response ) # Convert string to dictionary except Exception: cached_response = ast.literal_eval(cached_response) return cached_response def set_cache(self, key, value, **kwargs): print_verbose(f"qdrant semantic-cache set_cache, kwargs: {kwargs}") import uuid # get the prompt messages = kwargs["messages"] prompt = "" for message in messages: prompt += message["content"] # create an embedding for prompt embedding_response = litellm.embedding( model=self.embedding_model, input=prompt, cache={"no-store": True, "no-cache": True}, ) # get the embedding embedding = embedding_response["data"][0]["embedding"] value = str(value) assert isinstance(value, str) data = { "points": [ { "id": str(uuid.uuid4()), "vector": embedding, "payload": { "text": prompt, "response": value, }, }, ] } self.sync_client.put( url=f"{self.qdrant_api_base}/collections/{self.collection_name}/points", headers=self.headers, json=data, ) return def get_cache(self, key, **kwargs): print_verbose(f"sync qdrant semantic-cache get_cache, kwargs: {kwargs}") # get the messages messages = kwargs["messages"] prompt = "" for message in messages: prompt += message["content"] # convert to embedding embedding_response = litellm.embedding( model=self.embedding_model, input=prompt, cache={"no-store": True, "no-cache": True}, ) # get the embedding embedding = embedding_response["data"][0]["embedding"] data = { "vector": embedding, "params": { "quantization": { "ignore": False, "rescore": True, "oversampling": 3.0, } }, "limit": 1, "with_payload": True, } search_response = self.sync_client.post( url=f"{self.qdrant_api_base}/collections/{self.collection_name}/points/search", headers=self.headers, json=data, ) results = search_response.json()["result"] if results is None: return None if isinstance(results, list): if len(results) == 0: return None similarity = results[0]["score"] cached_prompt = results[0]["payload"]["text"] # check similarity, if more than self.similarity_threshold, return results print_verbose( f"semantic cache: similarity threshold: {self.similarity_threshold}, similarity: {similarity}, prompt: {prompt}, closest_cached_prompt: {cached_prompt}" ) if similarity >= self.similarity_threshold: # cache hit ! cached_value = results[0]["payload"]["response"] print_verbose( f"got a cache hit, similarity: {similarity}, Current prompt: {prompt}, cached_prompt: {cached_prompt}" ) return self._get_cache_logic(cached_response=cached_value) else: # cache miss ! return None pass async def async_set_cache(self, key, value, **kwargs): import uuid from litellm.proxy.proxy_server import llm_model_list, llm_router print_verbose(f"async qdrant semantic-cache set_cache, kwargs: {kwargs}") # get the prompt messages = kwargs["messages"] prompt = "" for message in messages: prompt += message["content"] # create an embedding for prompt router_model_names = ( [m["model_name"] for m in llm_model_list] if llm_model_list is not None else [] ) if llm_router is not None and self.embedding_model in router_model_names: user_api_key = kwargs.get("metadata", {}).get("user_api_key", "") embedding_response = await llm_router.aembedding( model=self.embedding_model, input=prompt, cache={"no-store": True, "no-cache": True}, metadata={ "user_api_key": user_api_key, "semantic-cache-embedding": True, "trace_id": kwargs.get("metadata", {}).get("trace_id", None), }, ) else: # convert to embedding embedding_response = await litellm.aembedding( model=self.embedding_model, input=prompt, cache={"no-store": True, "no-cache": True}, ) # get the embedding embedding = embedding_response["data"][0]["embedding"] value = str(value) assert isinstance(value, str) data = { "points": [ { "id": str(uuid.uuid4()), "vector": embedding, "payload": { "text": prompt, "response": value, }, }, ] } await self.async_client.put( url=f"{self.qdrant_api_base}/collections/{self.collection_name}/points", headers=self.headers, json=data, ) return async def async_get_cache(self, key, **kwargs): print_verbose(f"async qdrant semantic-cache get_cache, kwargs: {kwargs}") from litellm.proxy.proxy_server import llm_model_list, llm_router # get the messages messages = kwargs["messages"] prompt = "" for message in messages: prompt += message["content"] router_model_names = ( [m["model_name"] for m in llm_model_list] if llm_model_list is not None else [] ) if llm_router is not None and self.embedding_model in router_model_names: user_api_key = kwargs.get("metadata", {}).get("user_api_key", "") embedding_response = await llm_router.aembedding( model=self.embedding_model, input=prompt, cache={"no-store": True, "no-cache": True}, metadata={ "user_api_key": user_api_key, "semantic-cache-embedding": True, "trace_id": kwargs.get("metadata", {}).get("trace_id", None), }, ) else: # convert to embedding embedding_response = await litellm.aembedding( model=self.embedding_model, input=prompt, cache={"no-store": True, "no-cache": True}, ) # get the embedding embedding = embedding_response["data"][0]["embedding"] data = { "vector": embedding, "params": { "quantization": { "ignore": False, "rescore": True, "oversampling": 3.0, } }, "limit": 1, "with_payload": True, } search_response = await self.async_client.post( url=f"{self.qdrant_api_base}/collections/{self.collection_name}/points/search", headers=self.headers, json=data, ) results = search_response.json()["result"] if results is None: kwargs.setdefault("metadata", {})["semantic-similarity"] = 0.0 return None if isinstance(results, list): if len(results) == 0: kwargs.setdefault("metadata", {})["semantic-similarity"] = 0.0 return None similarity = results[0]["score"] cached_prompt = results[0]["payload"]["text"] # check similarity, if more than self.similarity_threshold, return results print_verbose( f"semantic cache: similarity threshold: {self.similarity_threshold}, similarity: {similarity}, prompt: {prompt}, closest_cached_prompt: {cached_prompt}" ) # update kwargs["metadata"] with similarity, don't rewrite the original metadata kwargs.setdefault("metadata", {})["semantic-similarity"] = similarity if similarity >= self.similarity_threshold: # cache hit ! cached_value = results[0]["payload"]["response"] print_verbose( f"got a cache hit, similarity: {similarity}, Current prompt: {prompt}, cached_prompt: {cached_prompt}" ) return self._get_cache_logic(cached_response=cached_value) else: # cache miss ! return None pass async def _collection_info(self): return self.collection_info async def async_set_cache_pipeline(self, cache_list, **kwargs): tasks = [] for val in cache_list: tasks.append(self.async_set_cache(val[0], val[1], **kwargs)) await asyncio.gather(*tasks)