"Module with query request and evaluation of wandb"
from openai import OpenAI
import wandb
import pandas as pd
import hydra
from omegaconf import DictConfig
import os
import datetime
from wandb.sdk.data_types.trace_tree import Trace
from dotenv import load_dotenv
load_dotenv()
api_key = os.getenv("OPENAI_API_KEY")
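# A minimal sanity check (added as a sketch): OpenAI() can also read
# OPENAI_API_KEY from the environment on its own, but failing fast here
# surfaces a missing .env entry before any request is made.
if api_key is None:
    raise RuntimeError("OPENAI_API_KEY is not set; add it to your .env file.")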


def run_query(client: OpenAI, system_message: str, query: str, openai_params: dict):
    """Send one chat-completion request and return status, timing, token
    usage, and the response text."""
messages = [
{"role": "system", "content": system_message},
{"role": "user", "content": query},
]
start_time_ms = datetime.datetime.now().timestamp() * 1000
try:
        # non-streaming response: the full completion arrives in one message
        if not openai_params.get("stream", False):
response = client.chat.completions.create(
**openai_params,
messages=messages,
)
end_time_ms = datetime.datetime.now().timestamp() * 1000
status = "success"
            status_message = None
response_text = response.choices[0].message.content
token_usage = dict(response.usage)
        # streaming response: collect the content chunks as they arrive
        else:
            response = client.chat.completions.create(
                **openai_params, messages=messages
            )
            collected_messages = []
            for chunk in response:
                chunk_message = chunk.choices[0].delta.content  # extract the message
                collected_messages.append(chunk_message)
            # only stop the clock once the stream has been fully consumed
            end_time_ms = datetime.datetime.now().timestamp() * 1000
            status = "success"
            status_message = None
            # drop the None sentinel that terminates the stream
            collected_messages = [m for m in collected_messages if m is not None]
            response_text = "".join(collected_messages)
            # token usage is not reported for streamed responses here; newer
            # SDK versions can request it via stream_options={"include_usage": True}
            token_usage = "no token usage information when streaming"
except Exception as e:
end_time_ms = datetime.datetime.now().timestamp() * 1000
status = "error"
status_message = str(e)
token_usage = {}
response_text = "error"
    return {
        "status": status,
        "status_message": status_message,
        "start_time_ms": start_time_ms,
        "end_time_ms": end_time_ms,
        "running_time_ms": end_time_ms - start_time_ms,
        "token_usage": token_usage,
        "response_text": response_text,
    }
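

# Illustrative usage of run_query, kept out of the pipeline's control flow;
# the model name and prompts below are placeholders, not values from the
# hydra config (a sketch, assuming a non-streaming chat model is available).
def example_single_query() -> dict:
    """Run one ad-hoc query; assumes OPENAI_API_KEY is set in the environment."""
    example_client = OpenAI(api_key=api_key)
    example_params = {"model": "gpt-3.5-turbo", "temperature": 0.7, "stream": False}
    return run_query(
        client=example_client,
        system_message="You are a helpful assistant.",
        query="What is 2 + 2?",
        openai_params=example_params,
    )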


@hydra.main(config_path="../../conf", config_name="train_llm")
def run_query_on_wandb(cfg: DictConfig):
"""Run Openai LLM and log results on wandb. Config file in conf/train_llm.yaml
Args:
cfg (DictConfig): configuration file for parameters
"""
run = wandb.init(
project=cfg.main.project_name,
group=cfg.main.experiment_name,
config=cfg.openai_parameters,
job_type="train_llm",
)
    # fetch the evaluation queries (first CSV column) from a wandb artifact
    artifact = run.use_artifact(cfg.parameters.data)
    artifact_path = artifact.file()
    queries = pd.read_csv(artifact_path, on_bad_lines="warn").iloc[:, 0].values
    # fetch the system prompt template from a wandb artifact
    artifact_st = run.use_artifact(cfg.parameters.system_template)
    artifact_st_path = artifact_st.file()
    with open(artifact_st_path) as file:
        system_message = file.read()
client = OpenAI(api_key=api_key)
    for query in queries:
res = run_query(
client=client,
system_message=system_message,
query=query,
openai_params=cfg.openai_parameters,
)
# create a span in wandb
root_span = Trace(
name="root_span",
kind="llm", # kind can be "llm", "chain", "agent" or "tool"
status_code=res["status"],
status_message=res["status_message"],
metadata={
"temperature": cfg.openai_parameters.temperature,
"token_usage": res["token_usage"],
"model_name": cfg.openai_parameters.model,
},
            start_time_ms=round(res["start_time_ms"]),
            end_time_ms=round(res["end_time_ms"]),
inputs={
"query": query,
"system_prompt": system_message,
},
outputs={"response": res["response_text"]},
)
# log the span to wandb
root_span.log(name="openai_trace")


if __name__ == "__main__":
run_query_on_wandb()
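# Typical invocation (the script filename is hypothetical; the overrides use
# standard hydra syntax and the keys sketched above for conf/train_llm.yaml):
#
#   python run_query_llm.py openai_parameters.temperature=0.2 \
#       openai_parameters.stream=true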