# Provenance: commit d3dec2d (marcellopoliti) — "update system template"
"Module with query request and evaluation of wandb"
from openai import OpenAI
import wandb
import pandas as pd
import hydra
from omegaconf import DictConfig
import os
import datetime
from wandb.sdk.data_types.trace_tree import Trace
from dotenv import load_dotenv
load_dotenv()  # load environment variables from a local .env file, if present
api_key = os.getenv("OPENAI_API_KEY")  # may be None if the variable is unset
def run_query(client: OpenAI, system_message: str, query: str, openai_params: dict):
    """Send one chat-completion request and collect timing/usage metadata.

    Args:
        client: OpenAI client used to issue the request.
        system_message: Content of the system-role message.
        query: Content of the user-role message.
        openai_params: Keyword arguments forwarded to
            ``chat.completions.create`` (model, temperature, stream, ...).

    Returns:
        dict with keys ``status``, ``status_message``, ``running_time_ms``,
        ``token_usage`` and ``response_text``. On failure ``status`` is
        "error", ``status_message`` holds the exception text, ``token_usage``
        is ``{}`` and ``response_text`` is "error".
    """
    messages = [
        {"role": "system", "content": system_message},
        {"role": "user", "content": query},
    ]
    start_time_ms = datetime.datetime.now().timestamp() * 1000
    try:
        # The create call is identical for streaming and non-streaming;
        # only the response handling differs (previously duplicated).
        response = client.chat.completions.create(**openai_params, messages=messages)
        end_time_ms = datetime.datetime.now().timestamp() * 1000
        status = "success"
        # BUG FIX: was `status_message = (None,)` — a stray trailing comma
        # made this a one-element tuple instead of None.
        status_message = None
        # ROBUSTNESS: .get() treats a missing "stream" key as non-streaming
        # instead of raising KeyError (which was swallowed as an "error" run).
        if not openai_params.get("stream"):
            response_text = response.choices[0].message.content
            token_usage = dict(response.usage)
        else:
            # Streaming: concatenate the non-None delta chunks;
            # the streaming API reports no token usage.
            chunk_texts = [chunk.choices[0].delta.content for chunk in response]
            response_text = "".join(text for text in chunk_texts if text is not None)
            token_usage = "no information with stream"
    except Exception as e:  # report any failure in the result dict, never raise
        end_time_ms = datetime.datetime.now().timestamp() * 1000
        status = "error"
        status_message = str(e)
        token_usage = {}
        response_text = "error"
    return {
        "status": status,
        "status_message": status_message,
        "running_time_ms": end_time_ms - start_time_ms,
        "token_usage": token_usage,
        "response_text": response_text,
    }
@hydra.main(config_path="../../conf", config_name="train_llm.yaml")
def run_query_on_wandb(cfg: DictConfig):
    """Run OpenAI LLM queries and log results on wandb.

    Config file in conf/train_llm.yaml.

    Args:
        cfg (DictConfig): configuration with ``main`` (project/experiment
            names), ``parameters`` (data and system-template artifact ids)
            and ``openai_parameters`` (forwarded to the OpenAI API).
    """
    run = wandb.init(
        project=cfg.main.project_name,
        group=cfg.main.experiment_name,
        config=cfg.openai_parameters,
        job_type="train_llm",
    )
    # Queries are taken from the first column of the CSV data artifact.
    artifact = run.use_artifact(cfg.parameters.data)
    artifact_path = artifact.file()
    data_frame = pd.read_csv(artifact_path, on_bad_lines="warn").iloc[:, 0].values
    # The system prompt is stored as a plain-text artifact.
    artifact_st = run.use_artifact(cfg.parameters.system_template)
    artifact_st_path = artifact_st.file()
    # BUG FIX: previously `open(...).read()` leaked the file handle.
    with open(artifact_st_path, encoding="utf-8") as template_file:
        system_message = template_file.read()
    client = OpenAI(api_key=api_key)
    for query in data_frame:  # enumerate() index was unused — iterate directly
        res = run_query(
            client=client,
            system_message=system_message,
            query=query,
            openai_params=cfg.openai_parameters,
        )
        # create a span in wandb for this query/response pair
        root_span = Trace(
            name="root_span",
            kind="llm",  # kind can be "llm", "chain", "agent" or "tool"
            status_code=res["status"],
            status_message=res["status_message"],
            metadata={
                "temperature": cfg.openai_parameters.temperature,
                "token_usage": res["token_usage"],
                "model_name": cfg.openai_parameters.model,
            },
            # start_time_ms=res["start_time_ms"],
            # end_time_ms=res["end_time_ms"],
            inputs={
                "query": query,
                "system_prompt": system_message,
            },
            outputs={"response": res["response_text"]},
        )
        # log the span to wandb
        root_span.log(name="openai_trace")
if __name__ == "__main__":
run_query_on_wandb()