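"""Thin wrapper around the Grazie API gateway with a disk-backed response cache.

Responses are cached per prompt in a pickle file so that repeated runs reuse
earlier generations; within a run, each repeated prompt consumes the next
cached response (or requests a fresh one once the cache is exhausted).
"""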
import pickle
import time
from grazie.api.client.chat.prompt import ChatPrompt
from grazie.api.client.endpoints import GrazieApiGatewayUrls
from grazie.api.client.gateway import AuthType, GrazieAgent, GrazieApiGatewayClient
from grazie.api.client.profiles import LLMProfile
import config
# Client for the Grazie API gateway; credentials and endpoint come from config.
client = GrazieApiGatewayClient(
    grazie_agent=GrazieAgent("grazie-toolformers", "v1.0"),
    url=GrazieApiGatewayUrls.STAGING,
    auth_type=AuthType.APPLICATION,
    grazie_jwt_token=config.GRAZIE_API_JWT_TOKEN,
)
# Disk-backed cache: maps each prompt to the list of responses generated for
# it. LLM_CACHE_USED tracks, per prompt, how many cached responses the current
# run has already consumed.
LLM_CACHE_FILE = config.CACHE_DIR / f"{config.LLM_MODEL}.cache.pkl"
LLM_CACHE = {}
LLM_CACHE_USED = {}

# Ensure the cache directory and an (initially empty) cache file exist, then
# load whatever was persisted by earlier runs.
config.CACHE_DIR.mkdir(parents=True, exist_ok=True)
if not LLM_CACHE_FILE.exists():
    with open(LLM_CACHE_FILE, "wb") as file:
        pickle.dump(obj=LLM_CACHE, file=file)
with open(LLM_CACHE_FILE, "rb") as file:
    LLM_CACHE = pickle.load(file=file)
def llm_request(prompt):
    """Send a single chat request, retrying until the gateway responds."""
    output = None
    while output is None:
        try:
            output = client.chat(
                chat=ChatPrompt().add_system("You are a helpful assistant.").add_user(prompt),
                profile=LLMProfile(config.LLM_MODEL),
            ).content
        except Exception:
            # Treat errors as transient: back off, then retry.
            time.sleep(config.GRAZIE_TIMEOUT_SEC)
    assert output is not None
    return output
def generate_for_prompt(prompt):
    """Return the next response for `prompt`, requesting and persisting a new
    one only once the cached responses for this run are exhausted."""
    if prompt not in LLM_CACHE:
        LLM_CACHE[prompt] = []
    if prompt not in LLM_CACHE_USED:
        LLM_CACHE_USED[prompt] = 0
    # Each repeated call with the same prompt consumes the next cached
    # response; when they run out, request a fresh one and persist the cache.
    while LLM_CACHE_USED[prompt] >= len(LLM_CACHE[prompt]):
        new_response = llm_request(prompt)
        LLM_CACHE[prompt].append(new_response)
        with open(LLM_CACHE_FILE, "wb") as file:
            pickle.dump(obj=LLM_CACHE, file=file)
    result = LLM_CACHE[prompt][LLM_CACHE_USED[prompt]]
    LLM_CACHE_USED[prompt] += 1
    return result
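
# Illustrative usage sketch, not part of the original module; the prompt string
# is made up. On a fresh cache, two calls with the same prompt in one run yield
# two distinct generations: the first request fills cache slot 0, the second
# call finds it already consumed and triggers a fresh request for slot 1.
if __name__ == "__main__":
    first = generate_for_prompt("Name a prime number below 10.")
    second = generate_for_prompt("Name a prime number below 10.")
    print(first)
    print(second)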