"Module with query request and evaluation of wandb" | |

import datetime
import os

import hydra
import pandas as pd
import wandb
from dotenv import load_dotenv
from omegaconf import DictConfig
from openai import OpenAI
from wandb.sdk.data_types.trace_tree import Trace

load_dotenv()
api_key = os.getenv("OPENAI_API_KEY")


def run_query(client: OpenAI, system_message: str, query: str, openai_params: dict):
    """Send one chat-completion request and return status, timing, usage, and text."""
    messages = [
        {"role": "system", "content": system_message},
        {"role": "user", "content": query},
    ]
    start_time_ms = datetime.datetime.now().timestamp() * 1000
    try:
        if not openai_params["stream"]:
            response = client.chat.completions.create(
                **openai_params,
                messages=messages,
            )
            end_time_ms = datetime.datetime.now().timestamp() * 1000
            status = "success"
            status_message = None
            response_text = response.choices[0].message.content
            token_usage = dict(response.usage)
        # streaming: accumulate the delta of each chunk into one response string
        else:
            response = client.chat.completions.create(
                **openai_params, messages=messages
            )
            end_time_ms = datetime.datetime.now().timestamp() * 1000
            status = "success"
            status_message = None
            collected_messages = []
            for chunk in response:
                chunk_message = chunk.choices[0].delta.content  # extract the delta text
                collected_messages.append(chunk_message)
            # drop the None deltas emitted at the start and end of the stream
            collected_messages = [m for m in collected_messages if m is not None]
            response_text = "".join(collected_messages)
            # the streaming API does not return token usage by default
            token_usage = "no information with stream"
    except Exception as e:
        end_time_ms = datetime.datetime.now().timestamp() * 1000
        status = "error"
        status_message = str(e)
        token_usage = {}
        response_text = "error"
    return {
        "status": status,
        "status_message": status_message,
        "running_time_ms": end_time_ms - start_time_ms,
        "token_usage": token_usage,
        "response_text": response_text,
    }
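

# A minimal sketch of calling run_query directly, outside the W&B pipeline.
# Assumes OPENAI_API_KEY is set; the parameter values are illustrative only:
#
#     client = OpenAI(api_key=api_key)
#     result = run_query(
#         client=client,
#         system_message="You are a helpful assistant.",
#         query="What is the capital of France?",
#         openai_params={"model": "gpt-3.5-turbo", "temperature": 0.0, "stream": False},
#     )
#     print(result["status"], result["running_time_ms"], result["response_text"])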


@hydra.main(config_path="conf", config_name="train_llm", version_base=None)
def run_query_on_wandb(cfg: DictConfig):
    """Run the OpenAI LLM on each query of a dataset and log traces to W&B.

    Configuration is read from conf/train_llm.yaml via Hydra.

    Args:
        cfg (DictConfig): configuration with run parameters
    """
    run = wandb.init(
        project=cfg.main.project_name,
        group=cfg.main.experiment_name,
        config=cfg.openai_parameters,
        job_type="train_llm",
    )
    # the first column of the dataset artifact holds the queries
    artifact = run.use_artifact(cfg.parameters.data)
    artifact_path = artifact.file()
    queries = pd.read_csv(artifact_path, on_bad_lines="warn").iloc[:, 0].values
    # the system prompt is stored as a separate artifact
    artifact_st = run.use_artifact(cfg.parameters.system_template)
    artifact_st_path = artifact_st.file()
    with open(artifact_st_path) as f:
        system_message = f.read()
    client = OpenAI(api_key=api_key)
    for query in queries:
        res = run_query(
            client=client,
            system_message=system_message,
            query=query,
            openai_params=cfg.openai_parameters,
        )
        # create a span in W&B for this query/response pair
        root_span = Trace(
            name="root_span",
            kind="llm",  # kind can be "llm", "chain", "agent" or "tool"
            status_code=res["status"],
            status_message=res["status_message"],
            metadata={
                "temperature": cfg.openai_parameters.temperature,
                "token_usage": res["token_usage"],
                "model_name": cfg.openai_parameters.model,
                "running_time_ms": res["running_time_ms"],  # total request latency
            },
            inputs={
                "query": query,
                "system_prompt": system_message,
            },
            outputs={"response": res["response_text"]},
        )
        # log the span to W&B
        root_span.log(name="openai_trace")


if __name__ == "__main__":
    run_query_on_wandb()
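

# A minimal sketch of the conf/train_llm.yaml this module expects. The key layout
# mirrors the attributes accessed above; every value shown is an assumption:
#
#     main:
#       project_name: my-llm-project
#       experiment_name: baseline
#     parameters:
#       data: queries:latest                  # W&B artifact, one query per CSV row
#       system_template: system_prompt:latest # W&B artifact with the system prompt
#     openai_parameters:
#       model: gpt-3.5-turbo
#       temperature: 0.0
#       stream: false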