"""Run OpenAI chat queries over a wandb artifact dataset and log traces to wandb.

Hydra configuration lives in conf/train_llm.yaml.
"""
import datetime
import os

import hydra
import pandas as pd
import wandb
from dotenv import load_dotenv
from omegaconf import DictConfig
from openai import OpenAI
from wandb.sdk.data_types.trace_tree import Trace

load_dotenv()
api_key = os.getenv("OPENAI_API_KEY")


def _now_ms() -> float:
    """Return the current wall-clock time in milliseconds since the epoch."""
    return datetime.datetime.now().timestamp() * 1000


def run_query(client: OpenAI, system_message: str, query: str, openai_params: dict) -> dict:
    """Send one chat completion request and collect timing/usage metadata.

    Args:
        client: an initialized OpenAI client.
        system_message: content of the system role message.
        query: content of the user role message.
        openai_params: kwargs forwarded to ``chat.completions.create``
            (model, temperature, stream, ...).

    Returns:
        dict with keys ``status`` ("success"/"error"), ``status_message``
        (None on success, the exception text on error), ``running_time_ms``,
        ``token_usage`` (dict, or an explanatory string when streaming —
        the streaming API does not report usage), and ``response_text``.
    """
    messages = [
        {"role": "system", "content": system_message},
        {"role": "user", "content": query},
    ]
    start_time_ms = _now_ms()
    try:
        # Single create() call serves both modes; the param dict decides
        # whether the SDK returns a completion object or a chunk stream.
        response = client.chat.completions.create(**openai_params, messages=messages)
        if openai_params.get("stream"):
            # Consume the stream fully BEFORE taking the end timestamp so
            # running_time_ms covers the whole generation, not just the
            # request handshake.
            deltas = (chunk.choices[0].delta.content for chunk in response)
            response_text = "".join(d for d in deltas if d is not None)
            token_usage = "no information with stream"
        else:
            response_text = response.choices[0].message.content
            token_usage = dict(response.usage)
        end_time_ms = _now_ms()
        status = "success"
        status_message = None  # fixed: was the one-element tuple (None,)
    except Exception as e:  # boundary: report any API/config failure in the result
        end_time_ms = _now_ms()
        status = "error"
        status_message = str(e)
        token_usage = {}
        response_text = "error"
    return {
        "status": status,
        "status_message": status_message,
        "running_time_ms": end_time_ms - start_time_ms,
        "token_usage": token_usage,
        "response_text": response_text,
    }


@hydra.main(config_path="../../conf", config_name="train_llm.yaml")
def run_query_on_wandb(cfg: DictConfig):
    """Run the OpenAI LLM over every query in the dataset and log wandb traces.

    Config file in conf/train_llm.yaml.

    Args:
        cfg (DictConfig): configuration file for parameters.
    """
    run = wandb.init(
        project=cfg.main.project_name,
        group=cfg.main.experiment_name,
        config=cfg.openai_parameters,
        job_type="train_llm",
    )
    try:
        # Dataset artifact: first CSV column holds the user queries.
        artifact = run.use_artifact(cfg.parameters.data)
        artifact_path = artifact.file()
        queries = pd.read_csv(artifact_path, on_bad_lines="warn").iloc[:, 0].values

        # System-prompt artifact: plain text file. Use a context manager so
        # the handle is closed (the original leaked it via open().read()).
        artifact_st = run.use_artifact(cfg.parameters.system_template)
        with open(artifact_st.file()) as fh:
            system_message = fh.read()

        client = OpenAI(api_key=api_key)
        for query in queries:
            res = run_query(
                client=client,
                system_message=system_message,
                query=query,
                openai_params=cfg.openai_parameters,
            )
            # One trace span per query; kind can be "llm", "chain",
            # "agent" or "tool".
            root_span = Trace(
                name="root_span",
                kind="llm",
                status_code=res["status"],
                status_message=res["status_message"],
                metadata={
                    "temperature": cfg.openai_parameters.temperature,
                    "token_usage": res["token_usage"],
                    "model_name": cfg.openai_parameters.model,
                },
                inputs={
                    "query": query,
                    "system_prompt": system_message,
                },
                outputs={"response": res["response_text"]},
            )
            root_span.log(name="openai_trace")
    finally:
        # Always close the run so wandb flushes buffered traces.
        run.finish()


if __name__ == "__main__":
    run_query_on_wandb()