""" |
|
Redis Semantic Cache implementation |
|
|
|
Has 4 methods: |
|
- set_cache |
|
- get_cache |
|
- async_set_cache |
|
- async_get_cache |
|
""" |
|
|
|

import ast
import asyncio
import json
from typing import Any

import litellm
from litellm._logging import print_verbose

from .base_cache import BaseCache


class RedisSemanticCache(BaseCache):
    def __init__(
        self,
        host=None,
        port=None,
        password=None,
        redis_url=None,
        similarity_threshold=None,
        use_async=False,
        embedding_model="text-embedding-ada-002",
        **kwargs,
    ):
        from redisvl.index import SearchIndex

        print_verbose(
            "redis semantic-cache initializing INDEX - litellm_semantic_cache_index"
        )
        if similarity_threshold is None:
            raise Exception("similarity_threshold must be provided, passed None")
        self.similarity_threshold = similarity_threshold
        self.embedding_model = embedding_model
        # redisvl schema: each cache entry is a Redis hash holding the raw response,
        # the prompt, and a 1536-dim float32 embedding compared with cosine distance
        schema = {
            "index": {
                "name": "litellm_semantic_cache_index",
                "prefix": "litellm",
                "storage_type": "hash",
            },
            "fields": {
                "text": [{"name": "response"}],
                "vector": [
                    {
                        "name": "litellm_embedding",
                        "dims": 1536,
                        "distance_metric": "cosine",
                        "algorithm": "flat",
                        "datatype": "float32",
                    }
                ],
            },
        }
        if redis_url is None:
            # no url passed in - fall back to REDIS_HOST / REDIS_PORT / REDIS_PASSWORD
            if host is None or port is None or password is None:
                import os

                host = os.getenv("REDIS_HOST")
                port = os.getenv("REDIS_PORT")
                password = os.getenv("REDIS_PASSWORD")
                if host is None or port is None or password is None:
                    raise Exception("Redis host, port, and password must be provided")

            redis_url = "redis://:" + password + "@" + host + ":" + port
        print_verbose(f"redis semantic-cache redis_url: {redis_url}")
        if use_async is False:
            self.index = SearchIndex.from_dict(schema)
            self.index.connect(redis_url=redis_url)
            try:
                self.index.create(overwrite=False)  # don't overwrite an existing index
            except Exception as e:
                print_verbose(f"Got exception creating semantic cache index: {str(e)}")
        elif use_async is True:
            # async index creation is deferred to async_set_cache (index.acreate)
            schema["index"]["name"] = "litellm_semantic_cache_index_async"
            self.index = SearchIndex.from_dict(schema)
            self.index.connect(redis_url=redis_url, use_async=True)

    def _get_cache_logic(self, cached_response: Any):
        """
        Common 'get_cache_logic' across sync + async redis client implementations
        """
        if cached_response is None:
            return cached_response

        # redis values come back as bytes - decode before parsing
        if isinstance(cached_response, bytes):
            cached_response = cached_response.decode("utf-8")

        try:
            # prefer JSON; fall back to ast.literal_eval for values stored via str(dict)
            cached_response = json.loads(cached_response)
        except Exception:
            cached_response = ast.literal_eval(cached_response)
        return cached_response
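
    # Illustrative example of _get_cache_logic (values are made up for this sketch):
    #   b'{"id": "chatcmpl-123"}'  ->  {"id": "chatcmpl-123"}    via json.loads
    #   "{'id': 'chatcmpl-123'}"   ->  {'id': 'chatcmpl-123'}    via the ast.literal_eval fallback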

    def set_cache(self, key, value, **kwargs):
        import numpy as np

        print_verbose(f"redis semantic-cache set_cache, kwargs: {kwargs}")

        # get the prompt
        messages = kwargs["messages"]
        prompt = "".join(message["content"] for message in messages)

        # create an embedding for the prompt
        embedding_response = litellm.embedding(
            model=self.embedding_model,
            input=prompt,
            cache={"no-store": True, "no-cache": True},
        )

        # get the embedding
        embedding = embedding_response["data"][0]["embedding"]

        # make the embedding a numpy array, convert to bytes
        embedding_bytes = np.array(embedding, dtype=np.float32).tobytes()
        value = str(value)
        assert isinstance(value, str)

        new_data = [
            {"response": value, "prompt": prompt, "litellm_embedding": embedding_bytes}
        ]

        # add the prompt + response + embedding to the index
        self.index.load(new_data)

        return

    def get_cache(self, key, **kwargs):
        print_verbose(f"sync redis semantic-cache get_cache, kwargs: {kwargs}")
        from redisvl.query import VectorQuery

        # get the messages
        messages = kwargs["messages"]
        prompt = "".join(message["content"] for message in messages)

        # convert the prompt to an embedding
        embedding_response = litellm.embedding(
            model=self.embedding_model,
            input=prompt,
            cache={"no-store": True, "no-cache": True},
        )

        # get the embedding
        embedding = embedding_response["data"][0]["embedding"]

        # query the index for the single closest cached prompt
        query = VectorQuery(
            vector=embedding,
            vector_field_name="litellm_embedding",
            return_fields=["response", "prompt", "vector_distance"],
            num_results=1,
        )

        results = self.index.query(query)
        if results is None:
            return None
        if isinstance(results, list):
            if len(results) == 0:
                return None

        vector_distance = results[0]["vector_distance"]
        vector_distance = float(vector_distance)
        similarity = 1 - vector_distance
        cached_prompt = results[0]["prompt"]

        # check similarity - only return the cached response if it clears the threshold
        print_verbose(
            f"semantic cache: similarity threshold: {self.similarity_threshold}, similarity: {similarity}, prompt: {prompt}, closest_cached_prompt: {cached_prompt}"
        )
        if similarity > self.similarity_threshold:
            # cache hit
            cached_value = results[0]["response"]
            print_verbose(
                f"got a cache hit, similarity: {similarity}, Current prompt: {prompt}, cached_prompt: {cached_prompt}"
            )
            return self._get_cache_logic(cached_response=cached_value)
        else:
            # cache miss
            return None
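
    # Worked example of the threshold check in get_cache (numbers are illustrative):
    # redisvl returns the cosine *distance* of the closest match, so similarity = 1 - vector_distance.
    # With similarity_threshold=0.8:
    #   vector_distance = 0.05  ->  similarity = 0.95  ->  cache hit
    #   vector_distance = 0.30  ->  similarity = 0.70  ->  cache miss (returns None)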

    async def async_set_cache(self, key, value, **kwargs):
        import numpy as np

        from litellm.proxy.proxy_server import llm_model_list, llm_router

        try:
            await self.index.acreate(overwrite=False)  # don't overwrite an existing index
        except Exception as e:
            print_verbose(f"Got exception creating semantic cache index: {str(e)}")
        print_verbose(f"async redis semantic-cache set_cache, kwargs: {kwargs}")

        # get the prompt
        messages = kwargs["messages"]
        prompt = "".join(message["content"] for message in messages)

        # create an embedding for the prompt - route through the proxy's router
        # if the embedding model is defined there, else call litellm directly
        router_model_names = (
            [m["model_name"] for m in llm_model_list]
            if llm_model_list is not None
            else []
        )
        if llm_router is not None and self.embedding_model in router_model_names:
            user_api_key = kwargs.get("metadata", {}).get("user_api_key", "")
            embedding_response = await llm_router.aembedding(
                model=self.embedding_model,
                input=prompt,
                cache={"no-store": True, "no-cache": True},
                metadata={
                    "user_api_key": user_api_key,
                    "semantic-cache-embedding": True,
                    "trace_id": kwargs.get("metadata", {}).get("trace_id", None),
                },
            )
        else:
            embedding_response = await litellm.aembedding(
                model=self.embedding_model,
                input=prompt,
                cache={"no-store": True, "no-cache": True},
            )

        # get the embedding
        embedding = embedding_response["data"][0]["embedding"]

        # make the embedding a numpy array, convert to bytes
        embedding_bytes = np.array(embedding, dtype=np.float32).tobytes()
        value = str(value)
        assert isinstance(value, str)

        new_data = [
            {"response": value, "prompt": prompt, "litellm_embedding": embedding_bytes}
        ]

        # add the prompt + response + embedding to the index
        await self.index.aload(new_data)
        return

    async def async_get_cache(self, key, **kwargs):
        print_verbose(f"async redis semantic-cache get_cache, kwargs: {kwargs}")
        from redisvl.query import VectorQuery

        from litellm.proxy.proxy_server import llm_model_list, llm_router

        # get the messages
        messages = kwargs["messages"]
        prompt = "".join(message["content"] for message in messages)

        # create an embedding for the prompt - route through the proxy's router
        # if the embedding model is defined there, else call litellm directly
        router_model_names = (
            [m["model_name"] for m in llm_model_list]
            if llm_model_list is not None
            else []
        )
        if llm_router is not None and self.embedding_model in router_model_names:
            user_api_key = kwargs.get("metadata", {}).get("user_api_key", "")
            embedding_response = await llm_router.aembedding(
                model=self.embedding_model,
                input=prompt,
                cache={"no-store": True, "no-cache": True},
                metadata={
                    "user_api_key": user_api_key,
                    "semantic-cache-embedding": True,
                    "trace_id": kwargs.get("metadata", {}).get("trace_id", None),
                },
            )
        else:
            embedding_response = await litellm.aembedding(
                model=self.embedding_model,
                input=prompt,
                cache={"no-store": True, "no-cache": True},
            )

        # get the embedding
        embedding = embedding_response["data"][0]["embedding"]

        # query the index for the closest cached prompt
        query = VectorQuery(
            vector=embedding,
            vector_field_name="litellm_embedding",
            return_fields=["response", "prompt", "vector_distance"],
        )
        results = await self.index.aquery(query)
        if results is None:
            kwargs.setdefault("metadata", {})["semantic-similarity"] = 0.0
            return None
        if isinstance(results, list):
            if len(results) == 0:
                kwargs.setdefault("metadata", {})["semantic-similarity"] = 0.0
                return None

        vector_distance = results[0]["vector_distance"]
        vector_distance = float(vector_distance)
        similarity = 1 - vector_distance
        cached_prompt = results[0]["prompt"]

        # check similarity - only return the cached response if it clears the threshold
        print_verbose(
            f"semantic cache: similarity threshold: {self.similarity_threshold}, similarity: {similarity}, prompt: {prompt}, closest_cached_prompt: {cached_prompt}"
        )

        # report the similarity back to the caller via request metadata
        kwargs.setdefault("metadata", {})["semantic-similarity"] = similarity

        if similarity > self.similarity_threshold:
            # cache hit
            cached_value = results[0]["response"]
            print_verbose(
                f"got a cache hit, similarity: {similarity}, Current prompt: {prompt}, cached_prompt: {cached_prompt}"
            )
            return self._get_cache_logic(cached_response=cached_value)
        else:
            # cache miss
            return None

    async def _index_info(self):
        return await self.index.ainfo()

    async def async_set_cache_pipeline(self, cache_list, **kwargs):
        tasks = []
        for val in cache_list:
            tasks.append(self.async_set_cache(val[0], val[1], **kwargs))
        await asyncio.gather(*tasks)
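

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the implementation). Assumes a
# reachable Redis instance with the RediSearch module and credentials for the
# configured embedding model; the URL, password, and threshold below are made
# up. In normal litellm usage this class is constructed by litellm's caching
# layer rather than instantiated directly.
#
#   semantic_cache = RedisSemanticCache(
#       redis_url="redis://:mypassword@localhost:6379",
#       similarity_threshold=0.8,
#       embedding_model="text-embedding-ada-002",
#   )
#   messages = [{"role": "user", "content": "What is the capital of France?"}]
#   semantic_cache.set_cache(key=None, value={"answer": "Paris"}, messages=messages)
#   hit = semantic_cache.get_cache(key=None, messages=messages)  # dict on hit, None on miss
# ---------------------------------------------------------------------------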