# semsearch / openai_utils.py
# Author: hanoch.rahimi@gmail
# Last change: increase number of retrieved items, prevent resending the same query
# Commit: ba1f3e2
import json
import time
import traceback
import openai
import requests
import streamlit as st
import utils
SEED = 42
def get_client():
    """Build an OpenAI client using the project's API key and organization id."""
    api_key = utils.OPENAI_API_KEY
    org_id = utils.OPENAI_ORGANIZATION_ID
    return openai.OpenAI(api_key=api_key, organization=org_id)
def getListOfCompanies(query, filters=None):
    """Search the vector index for companies matching *query* and return
    a text block of their summaries.

    Args:
        query: Free-text search query.
        filters: Optional dict of search filters. If it contains a
            'country' key, that value overrides the country filter taken
            from session state.

    Returns:
        str: Newline-joined "Description of company ..." lines for up to
        the first 20 results whose data carries a 'Summary' field.

    Side effects:
        Stores the raw search results in st.session_state.db_search_results.
    """
    # Avoid the mutable-default-argument pitfall: normalize None to {}.
    if filters is None:
        filters = {}
    country_filters = filters.get('country', st.session_state.country)
    st.session_state.db_search_results = utils.search_index(
        query,
        st.session_state.top_k,
        st.session_state.region,
        country_filters,
        st.session_state.retriever,
        st.session_state.index_namespace,
    )
    descriptions = "\n".join(
        f"Description of company \"{res['name']}\": {res['data']['Summary']}.\n"
        for res in st.session_state.db_search_results[:20]
        if 'Summary' in res['data']
    )
    return descriptions
def report_error(txt):
    """Print *txt* to stdout beneath a prominent error banner."""
    banner = "\nEEEEEEEEEEEEE\n"
    print(banner + f"{txt}")
def wait_for_response(thread, run):
    """Poll an OpenAI assistant *run* on *thread* until it finishes or times out.

    Repeatedly retrieves the run status. On 'completed' it returns the full
    message list of the thread. On 'requires_action' it executes the
    requested tool calls (only getListOfCompanies is supported) and submits
    their outputs back to the run. Any other terminal status is reported via
    report_error and the loop is abandoned; the thread's messages are then
    fetched and returned anyway so the caller always gets a messages object.

    Args:
        thread: OpenAI thread object (only .id is used).
        run: OpenAI run object (only .id is used).

    Returns:
        The OpenAI messages list for the thread.
    """
    timeout = 60 #timeout in seconds
    started = time.time()
    # NOTE(review): "True and" is redundant; the loop is bounded by the timeout.
    while True and time.time()-started<timeout:
        # Retrieve the run status
        run_status = st.session_state.openai_client.beta.threads.runs.retrieve(
            thread_id=thread.id,
            run_id=run.id
        )
        print(f"Run status: {run_status.status}")
        # Check and print the step details (logging only; no control-flow effect)
        run_steps = st.session_state.openai_client.beta.threads.runs.steps.list(
            thread_id=thread.id,
            run_id=run.id
        )
        for step in run_steps.data:
            #print(step)
            if step.type == 'tool_calls':
                print(f"\n--------------------\nTool {step.type} invoked.\n--------------------\n")
            # If step involves code execution, print the code
            # NOTE(review): step.type values are 'tool_calls'/'message_creation' in the
            # Assistants API, so this 'code_interpreter' comparison looks unreachable —
            # the code-interpreter detail lives under step.step_details; confirm intent.
            if step.type == 'code_interpreter':
                print(f"Python Code Executed: {step.step_details['code_interpreter']['input']}")
        if run_status.status == 'completed':
            # Retrieve all messages from the thread
            messages = st.session_state.openai_client.beta.threads.messages.list(
                thread_id=thread.id
            )
            # Print all messages from the thread
            for msg in messages.data:
                role = msg.role
                content = msg.content[0].text.value
                print(f"{role.capitalize()}: {content}")
            return messages
        elif run_status.status in ['queued', 'in_progress']:
            print(f'{run_status.status.capitalize()}... Please wait.')
            time.sleep(1.5) # Wait before checking again
        elif run_status.status == 'requires_action':
            # The assistant wants us to run tool calls locally and submit results.
            required_action = run_status.required_action
            if required_action.type == 'submit_tool_outputs':
                print(f"Requires tool outputs: {required_action}")
                outputs = {}
                for tool_call in required_action.submit_tool_outputs.tool_calls:
                    if tool_call.function.name =="getListOfCompanies":
                        try:
                            args = json.loads(tool_call.function.arguments)
                            res = ''
                            if 'query' in args:
                                print(f"Processing tool_call {tool_call.id}. Calling 'getListOfCompanies with args: {args}" )
                                # 'filters' arrives as a JSON-encoded string, not a dict.
                                search_filters = json.loads(args['filters']) if 'filters' in args else {}
                                res = getListOfCompanies(args['query'], search_filters)
                                outputs[tool_call.id] = res
                        except Exception as e:
                            # Best-effort: a failed tool call is logged and skipped,
                            # leaving its id out of the submitted outputs.
                            print(f"Error calling tools, {str(e)}")
                            traceback.print_exc()
                tool_outputs=[{"tool_call_id": k, "output": v} for (k,v) in outputs.items()]
                print(f"Finished tools calling: {str(tool_outputs)[:400]}")
                run = st.session_state.openai_client.beta.threads.runs.submit_tool_outputs(
                    thread_id=thread.id,
                    run_id=run.id,
                    tool_outputs=tool_outputs,
                )
                print(f"Required action {run_status.required_action}")
                #return run_status
            else:
                report_error(f"Unknown required action type: {required_action}")
                break
        else:
            # cancelled / failed / expired etc. — report and stop polling.
            report_error(f"Unhandled Run status: {run_status.status}\n\nError: {run_status.last_error}\n")
            break
    if time.time()-started>timeout:
        report_error(f"Wait for response timeout after {timeout}")
    report_error(f"Flow not completed")
    # Fall-through path (timeout or error): still return whatever messages exist.
    messages = st.session_state.openai_client.beta.threads.messages.list(
        thread_id=thread.id
    )
    return messages
def call_assistant(query, engine="gpt-3.5-turbo"): #, temp=0, top_p=1.0, max_tokens=4048):
    """Send *query* to the session's OpenAI assistant thread and wait for the reply.

    Args:
        query: User message text to post to the assistant thread.
        engine: Unused; kept for call-site compatibility with call_openai.

    Returns:
        The OpenAI messages list for the thread on success; the cached
        st.session_state.messages when *query* repeats the previous query;
        None when the API call raises.
    """
    # Prevent re-sending the last message over and over.
    print(f"Last query {st.session_state.last_user_query}, current query {query}")
    if st.session_state.last_user_query == query:
        report_error(f"That query '{query}' was just sent. We don't send the same query twice in a row. ")
        return st.session_state.messages
    try:
        thread = st.session_state.assistant_thread
        assistant_id = st.session_state.assistant_id
        # Create the user message for its side effect; the returned object is unused.
        st.session_state.openai_client.beta.threads.messages.create(
            thread.id,
            role="user",
            content=query,
        )
        run = st.session_state.openai_client.beta.threads.runs.create(
            thread_id=thread.id,
            assistant_id=assistant_id,
        )
        messages = wait_for_response(thread, run)
        print(f"====================\nOpen AI response\n {str(messages)[:1000]}\n====================\n")
        return messages
    except Exception as e:
        # Broad catch kept deliberately: a failed assistant call must not crash
        # the Streamlit app. Log the full traceback and signal failure explicitly.
        print(f"An error occurred: {str(e)}")
        traceback.print_exc()
        return None
def call_openai(prompt, engine="gpt-3.5-turbo", temp=0, top_p=1.0, max_tokens=4048):
    """Send *prompt* as a one-shot chat completion on top of the session history.

    Args:
        prompt: User message appended to st.session_state.messages for this call.
        engine: Chat model name.
        temp: Sampling temperature.
        top_p: Nucleus-sampling parameter, forwarded to the API.
        max_tokens: Completion token limit.

    Returns:
        str: The stripped completion text, or "Failed to generate a response."
        when the API call raises.

    Raises:
        Exception: when the session is configured for assistant-style reports.
    """
    if st.session_state.report_type=="assistant":
        raise Exception("use call_assistant instead of call_openai")
    try:
        response = st.session_state.openai_client.chat.completions.create(
            model=engine,
            messages=st.session_state.messages + [{"role": "user", "content": prompt}],
            temperature=temp,
            top_p=top_p,  # bug fix: parameter was accepted but never forwarded
            seed=SEED,    # fixed seed for reproducible outputs
            max_tokens=max_tokens,
        )
        print(f"====================\nOpen AI response\n {response}\n====================\n")
        text = response.choices[0].message.content.strip()
        return text
    except Exception as e:
        # Broad catch kept deliberately: degrade to a sentinel string rather
        # than crash the Streamlit app.
        print(f"An error occurred: {str(e)}")
        return "Failed to generate a response."
def send_message(role, content):
    """Append a message with the given *role* and *content* to the session's
    assistant thread. Called for its side effect; returns None."""
    # The created message object is not needed by any caller, so it is not kept
    # (the original bound it to an unused local).
    st.session_state.openai_client.beta.threads.messages.create(
        thread_id=st.session_state.assistant_thread.id,
        role=role,
        content=content
    )
def start_conversation():
    """Create a fresh assistant thread and store it in session state."""
    client = st.session_state.openai_client
    st.session_state.assistant_thread = client.beta.threads.create()
def run_assistant():
    """Start a run of the session's assistant on the current thread and block
    until it leaves the 'queued'/'in_progress' states.

    Returns:
        The final OpenAI run object (completed, requires_action, failed, ...).

    NOTE(review): there is no timeout guard here — a run stuck in
    'in_progress' would poll forever; confirm whether callers need one.
    """
    run = st.session_state.openai_client.beta.threads.runs.create(
        thread_id=st.session_state.assistant_thread.id,
        assistant_id=st.session_state.assistant.id,
    )
    while run.status in ("queued", "in_progress"):
        # Sleep BEFORE polling (the original slept after, always wasting one
        # extra 0.5 s sleep after the run had already reached a terminal state).
        time.sleep(0.5)
        run = st.session_state.openai_client.beta.threads.runs.retrieve(
            thread_id=st.session_state.assistant_thread.id,
            run_id=run.id,
        )
    return run