"Module with query request and evaluation of wandb"

from openai import OpenAI
import wandb
import pandas as pd
import hydra
from omegaconf import DictConfig
import os
import datetime
from wandb.sdk.data_types.trace_tree import Trace
from dotenv import load_dotenv

load_dotenv()
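# The OpenAI key is expected in the environment or in a local .env file.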
api_key = os.getenv("OPENAI_API_KEY")


def run_query(client: OpenAI, system_message: str, query: str, openai_params: dict):
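    """Send one chat completion request to OpenAI and time it.

    Args:
        client (OpenAI): initialised OpenAI client.
        system_message (str): system prompt prepended to the conversation.
        query (str): user query to send.
        openai_params (dict): keyword arguments forwarded to
            client.chat.completions.create (model, temperature, stream, ...).

    Returns:
        dict: status, status message, timing, token usage and response text.
    """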
    messages = [
        {"role": "system", "content": system_message},
        {"role": "user", "content": query},
    ]
    start_time_ms = round(datetime.datetime.now().timestamp() * 1000)

    try:
        if not openai_params["stream"]:
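            # non-streaming request: the whole completion is returned in one response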
            response = client.chat.completions.create(
                **openai_params,
                messages=messages,
            )

            end_time_ms = round(datetime.datetime.now().timestamp() * 1000)
            status = "success"
            status_message = None
            response_text = response.choices[0].message.content
            token_usage = dict(response.usage)
        else:
            # streaming request: the completion arrives as a sequence of chunks
            response = client.chat.completions.create(
                **openai_params, messages=messages
            )
            end_time_ms = round(datetime.datetime.now().timestamp() * 1000)
            status = "success"
            status_message = None
            collected_messages = []
            for chunk in response:
                # extract the content delta from each streamed chunk
                chunk_message = chunk.choices[0].delta.content
                collected_messages.append(chunk_message)
            # drop None deltas (e.g. the final chunk) before joining
            collected_messages = [m for m in collected_messages if m is not None]
            response_text = "".join(collected_messages)
            token_usage = "not reported for streaming responses"
    except Exception as e:
        end_time_ms = round(datetime.datetime.now().timestamp() * 1000)
        status = "error"
        status_message = str(e)
        token_usage = {}
        response_text = "error"

    return {
        "status": status,
        "status_message": status_message,
        "start_time_ms": start_time_ms,
        "end_time_ms": end_time_ms,
        "running_time_ms": end_time_ms - start_time_ms,
        "token_usage": token_usage,
        "response_text": response_text,
    }


@hydra.main(config_path="../../conf", config_name="train_llm.yaml")
def run_query_on_wandb(cfg: DictConfig):
    """Run Openai LLM and log results on wandb. Config file in conf/train_llm.yaml

    Args:
        cfg (DictConfig): configuration file for parameters
    """

    run = wandb.init(
        project=cfg.main.project_name,
        group=cfg.main.experiment_name,
        config=cfg.openai_parameters,
        job_type="train_llm",
    )

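    # fetch the query dataset and the system prompt from wandb artifacts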
    artifact = run.use_artifact(cfg.parameters.data)
    artifact_path = artifact.file()
    # the first column of the CSV artifact holds the user queries
    queries = pd.read_csv(artifact_path, on_bad_lines="warn").iloc[:, 0].values

    artifact_st = run.use_artifact(cfg.parameters.system_template)
    artifact_st_path = artifact_st.file()
    with open(artifact_st_path) as f:
        system_message = f.read()

    client = OpenAI(api_key=api_key)

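    # run every query through the model and log one trace per query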
    for query in queries:
        res = run_query(
            client=client,
            system_message=system_message,
            query=query,
            openai_params=cfg.openai_parameters,
        )

        # create a span in wandb
        root_span = Trace(
            name="root_span",
            kind="llm",  # kind can be "llm", "chain", "agent" or "tool"
            status_code=res["status"],
            status_message=res["status_message"],
            metadata={
                "temperature": cfg.openai_parameters.temperature,
                "token_usage": res["token_usage"],
                "model_name": cfg.openai_parameters.model,
            },
            start_time_ms=res["start_time_ms"],
            end_time_ms=res["end_time_ms"],
            inputs={
                "query": query,
                "system_prompt": system_message,
            },
            outputs={"response": res["response_text"]},
        )
        # log the span to wandb
        root_span.log(name="openai_trace")


if __name__ == "__main__":
    run_query_on_wandb()
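
# Typical invocation from the repository root (script path here is hypothetical;
# Hydra overrides use keys from conf/train_llm.yaml):
#   python train_llm.py openai_parameters.temperature=0.2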