import pickle
import time

from grazie.api.client.chat.prompt import ChatPrompt
from grazie.api.client.endpoints import GrazieApiGatewayUrls
from grazie.api.client.gateway import AuthType, GrazieAgent, GrazieApiGatewayClient
from grazie.api.client.profiles import LLMProfile

import config
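
# Gateway client for the Grazie LLM API; the JWT token and model name are
# read from the local `config` module.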
client = GrazieApiGatewayClient(
    grazie_agent=GrazieAgent("grazie-toolformers", "v1.0"),
    url=GrazieApiGatewayUrls.STAGING,
    auth_type=AuthType.APPLICATION,
    grazie_jwt_token=config.GRAZIE_API_JWT_TOKEN,
)
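
# Pickle-backed response cache keyed by prompt. Each prompt maps to a list
# of responses; LLM_CACHE_USED tracks how many of them this run has already
# consumed, so repeated calls for the same prompt can return distinct samples.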
LLM_CACHE_FILE = config.CACHE_DIR / f"{config.LLM_MODEL}.cache.pkl"

LLM_CACHE = {}
LLM_CACHE_USED = {}

# Create an empty cache file on first run, then load whatever is on disk.
if not LLM_CACHE_FILE.exists():
    with open(LLM_CACHE_FILE, "wb") as file:
        pickle.dump(obj=LLM_CACHE, file=file)

with open(LLM_CACHE_FILE, "rb") as file:
    LLM_CACHE = pickle.load(file=file)


def llm_request(prompt):
    """Send `prompt` to the configured LLM, retrying until a response arrives."""
    output = None
    while output is None:
        try:
            output = client.chat(
                chat=ChatPrompt().add_system("You are a helpful assistant.").add_user(prompt),
                profile=LLMProfile(config.LLM_MODEL),
            ).content
        except Exception:
            # Any API error (rate limit, transient failure) triggers a
            # fixed back-off before the request is retried.
            time.sleep(config.GRAZIE_TIMEOUT_SEC)
    assert output is not None
    return output


def generate_for_prompt(prompt):
    """Return the next unused response for `prompt`, generating it if necessary."""
    if prompt not in LLM_CACHE:
        LLM_CACHE[prompt] = []
    if prompt not in LLM_CACHE_USED:
        LLM_CACHE_USED[prompt] = 0
    # Generate new responses until an unused cached response is available,
    # persisting the cache to disk after each generation.
    while LLM_CACHE_USED[prompt] >= len(LLM_CACHE[prompt]):
        new_response = llm_request(prompt)
        LLM_CACHE[prompt].append(new_response)
        with open(LLM_CACHE_FILE, "wb") as file:
            pickle.dump(obj=LLM_CACHE, file=file)
    result = LLM_CACHE[prompt][LLM_CACHE_USED[prompt]]
    LLM_CACHE_USED[prompt] += 1
    return result
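

# Illustrative usage sketch, not part of the original module: within one run,
# each call for the same prompt consumes the next cached response and falls
# back to the API once the cached list is exhausted, so the two calls below
# return independent samples; on a later run both are served from disk.
if __name__ == "__main__":
    first = generate_for_prompt("Say hello in one word.")
    second = generate_for_prompt("Say hello in one word.")
    print(first)
    print(second)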