# Spaces:
# Paused
# Paused
| import json | |
| import time | |
| import traceback | |
| import openai | |
| import requests | |
| import streamlit as st | |
| import utils | |
# Fixed sampling seed passed to chat completions so outputs are reproducible.
SEED = 42
def get_client():
    """Build an OpenAI API client from the project-wide credentials.

    Returns:
        An ``openai.OpenAI`` client authenticated with the API key and
        organization id configured in ``utils``.
    """
    client = openai.OpenAI(
        api_key=utils.OPENAI_API_KEY,
        organization=utils.OPENAI_ORGANIZATION_ID,
    )
    return client
def getListOfCompanies(query, filters=None):
    """Search the company index and return matching company descriptions.

    Side effect: stores the raw search hits in
    ``st.session_state.db_search_results``.

    Args:
        query: Free-text search query.
        filters: Optional mapping of filter names to values; only the
            'country' key is honoured here. Defaults to no extra filters,
            in which case the country stored in the Streamlit session is
            used. (Was a mutable default ``{}`` — shared across calls.)

    Returns:
        str: Newline-joined descriptions of up to the first 20 hits that
        carry a 'Summary' field.
    """
    if filters is None:
        filters = {}
    # A caller-supplied country filter wins over the session default.
    country_filters = filters['country'] if 'country' in filters else st.session_state.country
    st.session_state.db_search_results = utils.search_index(
        query,
        st.session_state.top_k,
        st.session_state.region,
        country_filters,
        st.session_state.retriever,
        st.session_state.index_namespace,
    )
    descriptions = "\n".join(
        [
            f"Description of company \"{res['name']}\": {res['data']['Summary']}.\n"
            for res in st.session_state.db_search_results[:20]
            if 'Summary' in res['data']
        ]
    )
    return descriptions
def wait_for_response(thread, run):
    """Poll an assistant run until it finishes, servicing tool calls on the way.

    Args:
        thread: The OpenAI thread the run belongs to.
        run: The run object returned by ``runs.create``.

    Returns:
        The thread's message list when the run completes, or the last
        run-status object when the run ends in any other terminal state.
    """
    while True:
        # Retrieve the current run status.
        run_status = st.session_state.openai_client.beta.threads.runs.retrieve(
            thread_id=thread.id,
            run_id=run.id
        )
        # Log step details so tool activity is visible on the console.
        run_steps = st.session_state.openai_client.beta.threads.runs.steps.list(
            thread_id=thread.id,
            run_id=run.id
        )
        for step in run_steps.data:
            if step.type == 'tool_calls':
                print(f"\n--------------------\nTool {step.type} invoked.\n--------------------\n")
                # Run-step types are only 'message_creation'/'tool_calls';
                # code-interpreter activity is nested inside the step's
                # tool_calls detail, so the old `step.type ==
                # 'code_interpreter'` branch could never fire.
                for step_tool_call in getattr(step.step_details, 'tool_calls', []) or []:
                    if getattr(step_tool_call, 'type', None) == 'code_interpreter':
                        print(f"Python Code Executed: {step_tool_call.code_interpreter.input}")
        if run_status.status == 'completed':
            # Retrieve all messages from the thread.
            messages = st.session_state.openai_client.beta.threads.messages.list(
                thread_id=thread.id
            )
            # Echo the whole conversation for debugging.
            for msg in messages.data:
                role = msg.role
                content = msg.content[0].text.value
                print(f"{role.capitalize()}: {content}")
            return messages
        elif run_status.status in ['queued', 'in_progress']:
            print(f'{run_status.status.capitalize()}... Please wait.')
            time.sleep(1.5)  # Wait before checking again.
        elif run_status.status == 'requires_action':
            required_action = run_status.required_action
            if required_action.type == 'submit_tool_outputs':
                outputs = {}
                for tool_call in required_action.submit_tool_outputs.tool_calls:
                    if tool_call.function.name == "getListOfCompanies":
                        # Pre-seed an empty output: submit_tool_outputs requires
                        # an entry for every tool call, so a failure inside the
                        # try block must not leave this call unanswered.
                        outputs[tool_call.id] = ''
                        try:
                            args = json.loads(tool_call.function.arguments)
                            if 'query' in args:
                                print(f"Processing tool_call {tool_call.id}. Calling 'getListOfCompanies with args: {args}")
                                # Use a mapping (not a list) when no filters
                                # were given; downstream treats it as a dict.
                                search_filters = json.loads(args['filters']) if 'filters' in args else {}
                                outputs[tool_call.id] = getListOfCompanies(args['query'], search_filters)
                        except Exception as e:
                            print(f"Error calling tools, {str(e)}")
                            traceback.print_exc()
                tool_outputs = [{"tool_call_id": k, "output": v} for (k, v) in outputs.items()]
                print(f"Finished tools calling: {tool_outputs}")
                run = st.session_state.openai_client.beta.threads.runs.submit_tool_outputs(
                    thread_id=thread.id,
                    run_id=run.id,
                    tool_outputs=tool_outputs,
                )
                print(run_status.required_action)
            else:
                print(f"Unknown required action: {required_action.type}")
                return run_status
        else:
            print(f"Run status: {run_status.status}")
            return run_status
def call_openai(prompt, engine="gpt-3.5-turbo", temp=0, top_p=1.0, max_tokens=4048):
    """Send *prompt* to OpenAI and return the generated text.

    Routes through the Assistants API when ``st.session_state.report_type``
    is "assistant", otherwise issues a plain chat completion.

    Args:
        prompt: User message to send.
        engine: Chat-completion model name (unused on the assistant path).
        temp: Sampling temperature for chat completions.
        top_p: Accepted for interface compatibility; currently unused.
        max_tokens: Completion token cap for the chat path.

    Returns:
        str: The response text, or "Failed to generate a response." on error.
    """
    if st.session_state.report_type == "assistant":
        try:
            thread = st.session_state.assistant_thread
            assistant_id = st.session_state.assistant_id
            st.session_state.openai_client.beta.threads.messages.create(
                thread.id,
                role="user",
                content=prompt,
            )
            run = st.session_state.openai_client.beta.threads.runs.create(
                thread_id=thread.id,
                assistant_id=assistant_id,
            )
            messages = wait_for_response(thread, run)
            print(f"====================\nOpen AI response\n {messages}\n====================\n")
            text = ""
            for message in messages:
                text = text + "\n" + message.content[0].text.value
            return text
        except Exception as e:
            print(f"An error occurred: {str(e)}")
            # Mirror the chat branch so callers always receive a string
            # (previously this path fell through and returned None).
            return "Failed to generate a response."
    else:
        try:
            response = st.session_state.openai_client.chat.completions.create(
                model=engine,
                messages=st.session_state.messages + [{"role": "user", "content": prompt}],
                temperature=temp,
                seed=SEED,  # fixed seed for reproducible outputs
                max_tokens=max_tokens
            )
            print(f"====================\nOpen AI response\n {response}\n====================\n")
            text = response.choices[0].message.content.strip()
            return text
        except Exception as e:
            print(f"An error occurred: {str(e)}")
            return "Failed to generate a response."
def send_message(role, content):
    """Append a message to the active assistant thread.

    Args:
        role: Message author role (e.g. "user").
        content: Message text.
    """
    # The created message object is not used, so don't bind it.
    st.session_state.openai_client.beta.threads.messages.create(
        thread_id=st.session_state.assistant_thread.id,
        role=role,
        content=content
    )
def start_conversation():
    """Create a fresh assistant thread and store it in the Streamlit session."""
    st.session_state.assistant_thread = st.session_state.openai_client.beta.threads.create()
def run_assistant():
    """Start a run on the current assistant thread and block until it leaves
    the queued/in-progress states.

    Returns:
        The final run object (completed, failed, requires_action, ...).
    """
    run = st.session_state.openai_client.beta.threads.runs.create(
        thread_id=st.session_state.assistant_thread.id,
        assistant_id=st.session_state.assistant.id,
    )
    while run.status in ("queued", "in_progress"):
        # Sleep before re-polling so we never pause an extra 0.5 s after the
        # run has already reached a terminal state (the old loop slept after
        # every retrieve, including the final one).
        time.sleep(0.5)
        run = st.session_state.openai_client.beta.threads.runs.retrieve(
            thread_id=st.session_state.assistant_thread.id,
            run_id=run.id,
        )
    return run