# The async version is adapted from
# https://gist.github.com/neubig/80de662fb3e225c18172ec218be4917a
from __future__ import annotations

import ast
import asyncio
import os
import pathlib
import pdb
from typing import Any, List

import openai
import yaml

# from factool.env_config import factool_env_config

# env
# openai.api_key = factool_env_config.openai_api_key


class OpenAIChat():
    """Async helper around ``openai.ChatCompletion`` that parses replies as Python literals."""

    def __init__(
            self,
            model_name='gpt-3.5-turbo',
            max_tokens=2500,
            temperature=0,
            top_p=1,
            request_timeout=60,
    ):
        """Read the API key from the environment and store the request settings.

        Args:
            model_name: Model identifier passed to the chat completion call.
            max_tokens: ``max_tokens`` value sent with every request.
            temperature: ``temperature`` value sent with every request.
            top_p: ``top_p`` value sent with every request.
            request_timeout: ``request_timeout`` value sent with every request.

        Raises:
            AssertionError: If ``OPENAI_API_KEY`` is not set in the environment.
        """
        openai.api_key = os.environ.get("OPENAI_API_KEY", None)
        # Raised explicitly (rather than via ``assert``) so the check is not
        # stripped under ``python -O``; the exception type is unchanged.
        if openai.api_key is None:
            raise AssertionError("Please set the OPENAI_API_KEY environment variable.")
        if 'gpt' not in model_name:
            # Any model whose name lacks 'gpt' is sent to this local endpoint
            # (presumably a locally hosted OpenAI-compatible server).
            # NOTE: this mutates the module-global ``openai.api_base``.
            openai.api_base = "http://localhost:8000/v1"

        self.config = {
            'model_name': model_name,
            'max_tokens': max_tokens,
            'temperature': temperature,
            'top_p': top_p,
            'request_timeout': request_timeout,
        }

    def _boolean_fix(self, output):
        """Replace every ``true``/``false`` substring with ``True``/``False``.

        This is a plain substring replacement, so occurrences inside longer
        words or quoted strings are rewritten as well.
        """
        return output.replace("true", "True").replace("false", "False")

    def _type_check(self, output, expected_type):
        """Parse ``output`` as a Python literal and check its type.

        Args:
            output: Text to parse with ``ast.literal_eval``.
            expected_type: Type (or tuple of types) the parsed value must be
                an instance of.

        Returns:
            The parsed value, or ``None`` when parsing fails or the parsed
            value is not an instance of ``expected_type``.
        """
        try:
            output_eval = ast.literal_eval(output)
        except Exception:
            # Best-effort parse: any parsing failure means "no usable answer".
            # ``Exception`` (not a bare except) lets KeyboardInterrupt and
            # SystemExit propagate.
            return None
        if not isinstance(output_eval, expected_type):
            return None
        return output_eval

    async def dispatch_openai_requests(
            self,
            messages_list,
    ) -> list[Any]:
        """Dispatches requests to OpenAI API asynchronously.

        Args:
            messages_list: List of messages to be sent to OpenAI ChatCompletion API.
        Returns:
            List of responses from OpenAI API, in the same order as
            ``messages_list``. An entry is ``None`` when every attempt for
            that request hit one of the handled errors.
        """
        async def _request_with_retry(messages, retry=3):
            # Try the request up to ``retry`` times, sleeping after each
            # handled error; other exceptions propagate to the caller.
            for _ in range(retry):
                try:
                    response = await openai.ChatCompletion.acreate(
                        model=self.config['model_name'],
                        messages=messages,
                        max_tokens=self.config['max_tokens'],
                        temperature=self.config['temperature'],
                        top_p=self.config['top_p'],
                        request_timeout=self.config['request_timeout'],
                    )
                    return response
                except openai.error.RateLimitError:
                    print('Rate limit error, waiting for 40 second...')
                    await asyncio.sleep(40)
                except openai.error.APIError:
                    print('API error, waiting for 1 second...')
                    await asyncio.sleep(1)
                except openai.error.Timeout:
                    print('Timeout error, waiting for 1 second...')
                    await asyncio.sleep(1)
                except openai.error.ServiceUnavailableError:
                    print('Service unavailable error, waiting for 3 second...')
                    await asyncio.sleep(3)
                except openai.error.APIConnectionError:
                    print('API Connection error, waiting for 3 second...')
                    await asyncio.sleep(3)

            return None

        async_responses = [
            _request_with_retry(messages)
            for messages in messages_list
        ]

        return await asyncio.gather(*async_responses)

    async def async_run(self, messages_list, expected_type, retry=1):
        """Send every conversation and return the parsed, type-checked replies.

        Requests whose reply is missing, cannot be parsed as a Python literal,
        or has the wrong type are re-sent in the next round, for at most
        ``retry`` rounds.

        Args:
            messages_list: List of message lists, one per request.
            expected_type: Type (or tuple of types) each parsed reply must be
                an instance of.
            retry: Maximum number of dispatch rounds. Defaults to 1, i.e. a
                single round with no re-sending.

        Returns:
            A list aligned with ``messages_list`` holding the parsed reply for
            each request, or ``None`` where no valid reply was obtained.
        """
        responses = [None] * len(messages_list)
        # Indices into ``messages_list`` that still lack a valid reply.
        pending = list(range(len(messages_list)))

        while retry > 0 and pending:
            print(f'{retry} retry left...')
            batch = [messages_list[i] for i in pending]

            predictions = await self.dispatch_openai_requests(
                messages_list=batch,
            )

            # Single pass: record successes and collect failures for the next
            # round (replaces the quadratic "not in finished list" filter).
            still_pending = []
            for index, prediction in zip(pending, predictions):
                parsed = None
                if prediction is not None:
                    content = prediction['choices'][0]['message']['content']
                    parsed = self._type_check(self._boolean_fix(content), expected_type)
                if parsed is None:
                    still_pending.append(index)
                else:
                    responses[index] = parsed

            pending = still_pending
            retry -= 1

        return responses


class OpenAIEmbed():
    """Async helper around ``openai.Embedding`` using ``text-embedding-ada-002``."""

    def __init__(self):
        """Read the API key from the environment.

        Raises:
            AssertionError: If ``OPENAI_API_KEY`` is not set in the environment.
        """
        openai.api_key = os.environ.get("OPENAI_API_KEY", None)
        # Explicit raise so the check survives ``python -O``.
        if openai.api_key is None:
            raise AssertionError("Please set the OPENAI_API_KEY environment variable.")

    async def create_embedding(self, text, retry=3):
        """Request an embedding for ``text``.

        Args:
            text: Input passed to the embedding endpoint.
            retry: Maximum number of attempts.

        Returns:
            The API response, or ``None`` when every attempt hit one of the
            handled errors.
        """
        for _ in range(retry):
            try:
                response = await openai.Embedding.acreate(input=text, model="text-embedding-ada-002")
                return response
            except openai.error.RateLimitError:
                print('Rate limit error, waiting for 1 second...')
                await asyncio.sleep(1)
            except openai.error.APIError:
                print('API error, waiting for 1 second...')
                await asyncio.sleep(1)
            except openai.error.Timeout:
                print('Timeout error, waiting for 1 second...')
                await asyncio.sleep(1)
        return None

    async def process_batch(self, batch, retry=3):
        """Embed every text in ``batch`` concurrently.

        Args:
            batch: Iterable of inputs, each sent as its own request.
            retry: Maximum number of attempts per input.

        Returns:
            List of results from ``create_embedding``, in input order.
        """
        tasks = [self.create_embedding(text, retry=retry) for text in batch]
        return await asyncio.gather(*tasks)


if __name__ == "__main__":
    chat = OpenAIChat()

    # ``async_run`` is a coroutine function; it must be driven by an event
    # loop, otherwise no request is ever sent.
    predictions = asyncio.run(chat.async_run(
        messages_list=[
            [{"role": "user", "content": "show either 'ab' or '['a']'. Do not do anything else."}],
        ] * 20,
        expected_type=List,
    ))
    print(predictions)

    # Usage
    embed = OpenAIEmbed()
    batch = ["string1", "string2", "string3", "string4", "string5", "string6", "string7", "string8", "string9", "string10"]  # Your batch of strings
    embeddings = asyncio.run(embed.process_batch(batch, retry=3))
    for embedding in embeddings:
        print(embedding["data"][0]["embedding"])