TestLLM / litellm /caching /redis_semantic_cache.py
Raju2024's picture
Upload 1072 files
e3278e4 verified
raw
history blame
12.1 kB
"""
Redis Semantic Cache implementation
Has 4 methods:
- set_cache
- get_cache
- async_set_cache
- async_get_cache
"""
import ast
import asyncio
import json
from typing import Any
import litellm
from litellm._logging import print_verbose
from .base_cache import BaseCache
class RedisSemanticCache(BaseCache):
    """Semantic (embedding-similarity) cache stored in a Redis vector index.

    Every cached entry holds the stringified response, the prompt text and the
    prompt's embedding. A lookup embeds the incoming prompt, asks the index
    for the closest stored embedding and returns the stored response only when
    ``1 - cosine_distance`` is strictly greater than ``similarity_threshold``.
    """

    def __init__(
        self,
        host=None,
        port=None,
        password=None,
        redis_url=None,
        similarity_threshold=None,
        use_async=False,
        embedding_model="text-embedding-ada-002",
        **kwargs,
    ):
        """Build the search index and connect it to Redis.

        Args:
            host: Redis host; falls back to ``REDIS_HOST`` when omitted.
            port: Redis port (str or int); falls back to ``REDIS_PORT``.
            password: Redis password; falls back to ``REDIS_PASSWORD``.
            redis_url: Full connection URL. When given, host/port/password
                are ignored.
            similarity_threshold: Minimum similarity (exclusive) for a hit.
            use_async: When truthy, connect an async index named
                ``litellm_semantic_cache_index_async`` and defer index
                creation to ``async_set_cache``.
            embedding_model: Model name used to embed prompts.
            **kwargs: Accepted for call-site compatibility; unused here.

        Raises:
            Exception: If ``similarity_threshold`` is None, or if no
                ``redis_url`` is given and host/port/password cannot be
                resolved from the arguments or the environment.
        """
        # Imported lazily so the module can be imported without redisvl.
        from redisvl.index import SearchIndex

        print_verbose(
            "redis semantic-cache initializing INDEX - litellm_semantic_cache_index"
        )
        if similarity_threshold is None:
            raise Exception("similarity_threshold must be provided, passed None")
        self.similarity_threshold = similarity_threshold
        self.embedding_model = embedding_model
        schema = {
            "index": {
                "name": "litellm_semantic_cache_index",
                "prefix": "litellm",
                "storage_type": "hash",
            },
            "fields": {
                "text": [{"name": "response"}],
                "vector": [
                    {
                        "name": "litellm_embedding",
                        # NOTE(review): 1536 matches the default
                        # text-embedding-ada-002 model; a different
                        # embedding_model may need a different size.
                        "dims": 1536,
                        "distance_metric": "cosine",
                        "algorithm": "flat",
                        "datatype": "float32",
                    }
                ],
            },
        }
        if redis_url is None:
            import os

            # Only fall back to the environment for values the caller did not
            # pass, so explicit arguments are never silently overwritten.
            if host is None:
                host = os.getenv("REDIS_HOST")
            if port is None:
                port = os.getenv("REDIS_PORT")
            if password is None:
                password = os.getenv("REDIS_PASSWORD")
            if host is None or port is None or password is None:
                raise Exception("Redis host, port, and password must be provided")
            # f-string formatting also accepts an integer port.
            redis_url = f"redis://:{password}@{host}:{port}"
        # Never write the password to the log.
        print_verbose(
            f"redis semantic-cache redis_url: {self._redact_url(redis_url)}"
        )
        if use_async:
            schema["index"]["name"] = "litellm_semantic_cache_index_async"
            self.index = SearchIndex.from_dict(schema)
            self.index.connect(redis_url=redis_url, use_async=True)
        else:
            self.index = SearchIndex.from_dict(schema)
            self.index.connect(redis_url=redis_url)
            try:
                self.index.create(overwrite=False)  # don't overwrite existing index
            except Exception as e:
                # Best effort: a failure here is logged and construction continues.
                print_verbose(f"Got exception creating semantic cache index: {str(e)}")

    @staticmethod
    def _redact_url(redis_url):
        """Return ``redis_url`` with any credentials replaced by ``***``."""
        text = str(redis_url)
        scheme, separator, remainder = text.partition("://")
        if separator and "@" in remainder:
            return f"{scheme}://***@{remainder.rpartition('@')[2]}"
        return text

    @staticmethod
    def _prompt_from_messages(messages):
        """Concatenate the ``content`` of every message into one prompt string."""
        return "".join(message["content"] for message in messages)

    @staticmethod
    def _build_record(prompt, value, embedding):
        """Build the record stored in the index for one cached response."""
        import numpy as np

        # The vector field is declared float32, so store the raw float32 bytes.
        embedding_bytes = np.array(embedding, dtype=np.float32).tobytes()
        return {
            "response": str(value),
            "prompt": prompt,
            "litellm_embedding": embedding_bytes,
        }

    @staticmethod
    def _record_similarity(request_kwargs, similarity):
        """Store ``similarity`` under ``metadata["semantic-similarity"]``.

        An existing metadata dict is updated in place rather than replaced;
        a missing or ``None`` metadata entry is replaced by a new dict.
        """
        metadata = request_kwargs.get("metadata")
        if metadata is None:
            metadata = request_kwargs["metadata"] = {}
        metadata["semantic-similarity"] = similarity

    def _resolve_query_results(self, results, prompt, request_kwargs=None):
        """Turn vector-query results into a cached response or ``None``.

        Args:
            results: Query results; only the first (closest) entry is used.
            prompt: The incoming prompt, used for logging only.
            request_kwargs: When given, the computed similarity (0.0 if there
                are no results) is recorded in its ``metadata`` dict.

        Returns:
            The decoded cached response on a hit, otherwise ``None``.
        """
        if not results:
            if request_kwargs is not None:
                self._record_similarity(request_kwargs, 0.0)
            return None

        closest = results[0]
        # The index uses cosine distance, so similarity is its complement.
        similarity = 1 - float(closest["vector_distance"])
        cached_prompt = closest["prompt"]
        print_verbose(
            f"semantic cache: similarity threshold: {self.similarity_threshold}, similarity: {similarity}, prompt: {prompt}, closest_cached_prompt: {cached_prompt}"
        )
        if request_kwargs is not None:
            self._record_similarity(request_kwargs, similarity)

        if similarity > self.similarity_threshold:
            # cache hit !
            cached_value = closest["response"]
            print_verbose(
                f"got a cache hit, similarity: {similarity}, Current prompt: {prompt}, cached_prompt: {cached_prompt}"
            )
            return self._get_cache_logic(cached_response=cached_value)
        # cache miss !
        return None

    async def _async_embed_prompt(self, prompt, request_kwargs):
        """Embed ``prompt`` asynchronously and return the embedding vector.

        The proxy router is used when it exists and serves
        ``self.embedding_model``; otherwise ``litellm.aembedding`` is called
        directly. Caching is disabled for the embedding call itself.
        """
        from litellm.proxy.proxy_server import llm_model_list, llm_router

        router_model_names = (
            [m["model_name"] for m in llm_model_list]
            if llm_model_list is not None
            else []
        )
        if llm_router is not None and self.embedding_model in router_model_names:
            # Tolerate a missing or explicitly-None metadata entry.
            metadata = request_kwargs.get("metadata") or {}
            embedding_response = await llm_router.aembedding(
                model=self.embedding_model,
                input=prompt,
                cache={"no-store": True, "no-cache": True},
                metadata={
                    "user_api_key": metadata.get("user_api_key", ""),
                    "semantic-cache-embedding": True,
                    "trace_id": metadata.get("trace_id", None),
                },
            )
        else:
            embedding_response = await litellm.aembedding(
                model=self.embedding_model,
                input=prompt,
                cache={"no-store": True, "no-cache": True},
            )
        return embedding_response["data"][0]["embedding"]

    def _get_cache_logic(self, cached_response: Any):
        """
        Common 'get_cache_logic' across sync + async redis client implementations

        Decodes bytes as UTF-8, then parses the text as JSON; if that fails,
        falls back to ``ast.literal_eval`` (values are stored via ``str()``,
        so they may be Python literals rather than JSON). ``None`` is
        returned unchanged.
        """
        if cached_response is None:
            return cached_response

        # check if cached_response is bytes
        if isinstance(cached_response, bytes):
            cached_response = cached_response.decode("utf-8")

        try:
            cached_response = json.loads(
                cached_response
            )  # Convert string to dictionary
        except Exception:
            cached_response = ast.literal_eval(cached_response)
        return cached_response

    def set_cache(self, key, value, **kwargs):
        """Embed the prompt in ``kwargs["messages"]`` and store ``value``.

        ``key`` is unused: entries are matched by embedding similarity.
        """
        print_verbose(f"redis semantic-cache set_cache, kwargs: {kwargs}")

        prompt = self._prompt_from_messages(kwargs["messages"])

        # Embed the prompt; the embedding call itself must bypass the cache.
        embedding_response = litellm.embedding(
            model=self.embedding_model,
            input=prompt,
            cache={"no-store": True, "no-cache": True},
        )
        embedding = embedding_response["data"][0]["embedding"]

        self.index.load([self._build_record(prompt, value, embedding)])
        return

    def get_cache(self, key, **kwargs):
        """Return the cached response closest to ``kwargs["messages"]``.

        ``key`` is unused. Returns ``None`` when there is no stored entry or
        the closest one does not exceed ``similarity_threshold``.
        """
        print_verbose(f"sync redis semantic-cache get_cache, kwargs: {kwargs}")
        from redisvl.query import VectorQuery

        prompt = self._prompt_from_messages(kwargs["messages"])

        embedding_response = litellm.embedding(
            model=self.embedding_model,
            input=prompt,
            cache={"no-store": True, "no-cache": True},
        )
        embedding = embedding_response["data"][0]["embedding"]

        query = VectorQuery(
            vector=embedding,
            vector_field_name="litellm_embedding",
            return_fields=["response", "prompt", "vector_distance"],
            num_results=1,
        )
        results = self.index.query(query)
        return self._resolve_query_results(results, prompt)

    async def async_set_cache(self, key, value, **kwargs):
        """Async counterpart of ``set_cache``.

        The index is created here (best effort) because ``__init__`` does not
        create it in async mode. ``key`` is unused.
        """
        try:
            await self.index.acreate(overwrite=False)  # don't overwrite existing index
        except Exception as e:
            print_verbose(f"Got exception creating semantic cache index: {str(e)}")
        print_verbose(f"async redis semantic-cache set_cache, kwargs: {kwargs}")

        prompt = self._prompt_from_messages(kwargs["messages"])
        embedding = await self._async_embed_prompt(prompt, kwargs)

        await self.index.aload([self._build_record(prompt, value, embedding)])
        return

    async def async_get_cache(self, key, **kwargs):
        """Async counterpart of ``get_cache``.

        Also records the computed similarity (0.0 when nothing is stored) in
        ``kwargs["metadata"]["semantic-similarity"]``, updating the metadata
        dict in place. ``key`` is unused.
        """
        print_verbose(f"async redis semantic-cache get_cache, kwargs: {kwargs}")
        from redisvl.query import VectorQuery

        prompt = self._prompt_from_messages(kwargs["messages"])
        embedding = await self._async_embed_prompt(prompt, kwargs)

        query = VectorQuery(
            vector=embedding,
            vector_field_name="litellm_embedding",
            return_fields=["response", "prompt", "vector_distance"],
            # Only the closest entry is inspected, matching the sync path.
            num_results=1,
        )
        results = await self.index.aquery(query)
        return self._resolve_query_results(results, prompt, request_kwargs=kwargs)

    async def _index_info(self):
        """Return the result of the index's async ``ainfo()`` call."""
        return await self.index.ainfo()

    async def async_set_cache_pipeline(self, cache_list, **kwargs):
        """Store many entries concurrently.

        Args:
            cache_list: Iterable of sequences whose first two items are the
                key and the value to cache.
            **kwargs: Forwarded unchanged to every ``async_set_cache`` call.
        """
        tasks = [
            self.async_set_cache(val[0], val[1], **kwargs) for val in cache_list
        ]
        await asyncio.gather(*tasks)