# simon_says_v2/utils/openai.py
import os

import vcr
from dotenv import load_dotenv
from openai import OpenAI

load_dotenv()  # take environment variables from .env
client = OpenAI(
    base_url=os.environ.get("OPENAI_API_BASE"),
    api_key=os.environ.get("OPENAI_API_KEY"),
)
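
# The client reads its configuration from the environment (loaded above from
# .env). A minimal .env for local development might look like the sketch below;
# OPENAI_API_BASE is optional and only needed for a non-default endpoint, and
# both values here are placeholders rather than real credentials:
#
#   OPENAI_API_KEY=sk-...
#   OPENAI_API_BASE=https://api.openai.com/v1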


def completion(prompt, max_tokens=None, temperature=0):
    """Return the text of a single completion from the legacy completions endpoint."""
    _completion = client.completions.create(
        model="gpt-3.5-turbo-instruct",
        prompt=prompt,
        max_tokens=max_tokens,
        temperature=temperature,
    )
    return _completion.choices[0].text.strip()
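
# Example usage (a minimal sketch; the prompt text and token budget below are
# illustrative assumptions):
#
#   text = completion("Write a haiku about the ocean.", max_tokens=64)
#   print(text)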


def chat_completion(message, model="gpt-3.5-turbo", prompt=None, temperature=0):
    """Return the assistant's reply for a single-turn chat, with an optional system prompt."""
    # Initialize the messages list
    messages = []

    # Add the optional system prompt to the messages list
    if prompt is not None:
        messages += [{"role": "system", "content": prompt}]

    # Add the user's message to the messages list
    if message is not None:
        messages += [{"role": "user", "content": message}]

    # Call the chat completions endpoint with the model and messages
    _completion = client.chat.completions.create(
        model=model, messages=messages, temperature=temperature
    )

    # Extract and return the assistant's response text
    return _completion.choices[0].message.content.strip()
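
# Example usage (a minimal sketch; the system prompt and user message below are
# illustrative assumptions):
#
#   reply = chat_completion(
#       "What is the capital of France?",
#       prompt="You are a terse assistant. Answer in one word.",
#   )
#   print(reply)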


def __vcr():
    """Build a VCR recorder configured for this project's test cassettes."""
    return vcr.VCR(
        serializer="yaml",
        cassette_library_dir="tests/fixtures/cassettes",
        record_mode="new_episodes",
        match_on=["uri", "method", "path", "query", "body"],
        record_on_exception=False,
    )


def cassette_for(name):
    """Return a cassette context manager for ``name``, filtering the authorization header."""
    filename = name + ".yaml"
    return __vcr().use_cassette(filename, filter_headers=[("authorization", None)])
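
# Example test usage (a minimal sketch; the test name, the cassette name, and
# the assumption that this module is importable as ``utils.openai`` are
# illustrative):
#
#   from utils.openai import cassette_for, chat_completion
#
#   def test_chat_completion_replays_from_cassette():
#       with cassette_for("chat_completion_smoke"):
#           reply = chat_completion("Say hello.")
#           assert isinstance(reply, str) and reply != ""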