"""For specifying an LLM agent logic flow.

This is very much a prototype. It might end up merged into LynxScribe
as an "agentic logic flow". It might just get deleted.

(This is why the dependencies are left hanging.)
"""

import os
from lynxkite.core import ops
import enum
import jinja2
import json
import numpy as np
import pandas as pd
from lynxkite.core.executors import one_by_one

jinja = jinja2.Environment()
chroma_client = None
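# In-memory cache for chat and embedding results, keyed by the serialized request.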
LLM_CACHE = {}
ENV = "LLM logic"
one_by_one.register(ENV)
op = ops.op_registration(ENV)
LLM_BASE_URL = os.environ.get("LLM_BASE_URL", None)
EMBEDDING_BASE_URL = os.environ.get("EMBEDDING_BASE_URL", None)
LLM_MODEL = os.environ.get("LLM_MODEL", "gpt-4o-mini-2024-07-18")
EMBEDDING_MODEL = os.environ.get("EMBEDDING_MODEL", "text-embedding-3-small")


def chat(*args, **kwargs):
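    """Sends a chat completion request to the configured LLM and returns the message contents, caching results by the request parameters."""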
    import openai

    chat_client = openai.OpenAI(base_url=LLM_BASE_URL)
    kwargs.setdefault("model", LLM_MODEL)
    key = json.dumps(
        {"method": "chat", "base_url": LLM_BASE_URL, "args": args, "kwargs": kwargs}
    )
    if key not in LLM_CACHE:
        completion = chat_client.chat.completions.create(*args, **kwargs)
        LLM_CACHE[key] = [c.message.content for c in completion.choices]
    return LLM_CACHE[key]


def embedding(*args, **kwargs):
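    """Returns the embedding of a single input, caching results by the request parameters."""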
    import openai

    embedding_client = openai.OpenAI(base_url=EMBEDDING_BASE_URL)
    kwargs.setdefault("model", EMBEDDING_MODEL)
    key = json.dumps(
        {
            "method": "embedding",
            "base_url": EMBEDDING_BASE_URL,
            "args": args,
            "kwargs": kwargs,
        }
    )
    if key not in LLM_CACHE:
        res = embedding_client.embeddings.create(*args, **kwargs)
        [data] = res.data
        LLM_CACHE[key] = data.embedding
    return LLM_CACHE[key]


@op("Input CSV")
def input_csv(*, filename: ops.PathStr, key: str):
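    """Reads a CSV file and renames the given column to "text"."""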
    return pd.read_csv(filename).rename(columns={key: "text"})


@op("Input document")
def input_document(*, filename: ops.PathStr):
    with open(filename) as f:
        return {"text": f.read()}


@op("Input chat")
def input_chat(*, chat: str):
    return {"text": chat}


@op("Split document")
def split_document(input, *, delimiter: str = "\\n\\n"):
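    """Splits the document text into chunks on the delimiter (escape sequences such as "\\n" are interpreted)."""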
    delimiter = delimiter.encode().decode("unicode_escape")
    chunks = input["text"].split(delimiter)
    return pd.DataFrame(chunks, columns=["text"])


@ops.input_position(input="top")
@op("Build document graph")
def build_document_graph(input):
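    """Connects consecutive chunks of the document with edges."""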
    return [{"source": i, "target": i + 1} for i in range(len(input) - 1)]


@ops.input_position(nodes="top", edges="top")
@op("Predict links")
def predict_links(nodes, edges):
    """A placeholder for a real algorithm. For now just adds 2-hop neighbors."""
    edge_map = {}  # Source -> [Targets]
    for edge in edges:
        edge_map.setdefault(edge["source"], [])
        edge_map[edge["source"]].append(edge["target"])
    new_edges = []
    for edge in edges:
        for t in edge_map.get(edge["target"], []):
            new_edges.append({"source": edge["source"], "target": t})
    return edges + new_edges


@ops.input_position(nodes="top", edges="top")
@op("Add neighbors")
def add_neighbors(nodes, edges, item):
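    """Extends the RAG matches of the item with the graph neighbors of each matched node."""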
    nodes = pd.DataFrame(nodes)
    edges = pd.DataFrame(edges)
    matches = item["rag"]
    additional_matches = []
    for m in matches:
        node = nodes[nodes["text"] == m].index[0]
        neighbors = edges[edges["source"] == node]["target"].to_list()
        additional_matches.extend(nodes.loc[neighbors, "text"])
    return {**item, "rag": matches + additional_matches}


@op("Create prompt")
def create_prompt(input, *, save_as="prompt", template: ops.LongStr):
    assert template, (
        "Please specify the template. Refer to columns using the Jinja2 syntax."
    )
    t = jinja.from_string(template)
    prompt = t.render(**input)
    return {**input, save_as: prompt}


@op("Ask LLM")
def ask_llm(input, *, accepted_regex: str | None = None, max_tokens: int = 100):
    assert "prompt" in input, "Please create the prompt first."
    options = {}
    if accepted_regex:
        options["extra_body"] = {
            "regex": accepted_regex,
        }
    results = chat(
        max_tokens=max_tokens,
        messages=[
            {"role": "user", "content": input["prompt"]},
        ],
        **options,
    )
    return [{**input, "response": r} for r in results]


@op("View", view="table_view")
def view(input, *, _ctx: one_by_one.Context):
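    """Collects the incoming items into a table, one row per item, skipping fields that start with an underscore."""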
    v = _ctx.last_result
    if v:
        columns = v["dataframes"]["df"]["columns"]
        v["dataframes"]["df"]["data"].append([input[c] for c in columns])
    else:
        columns = [str(c) for c in input.keys() if not str(c).startswith("_")]
        v = {
            "dataframes": {
                "df": {
                    "columns": columns,
                    "data": [[input[c] for c in columns]],
                }
            }
        }
    return v


@ops.input_position(input="right")
@ops.output_position(output="left")
@op("Loop")
def loop(input, *, max_iterations: int = 3, _ctx: one_by_one.Context):
    """Data can flow back here max_iterations-1 times."""
    key = f"iterations-{_ctx.node.id}"
    input[key] = input.get(key, 0) + 1
    if input[key] < max_iterations:
        return input


@op("Branch", outputs=["true", "false"])
def branch(input, *, expression: str):
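    """Evaluates the expression on the input and routes the input to the "true" or "false" output."""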
    res = eval(expression, input)
    return one_by_one.Output(output_handle=str(bool(res)).lower(), value=input)


class RagEngine(enum.Enum):
    Chroma = "Chroma"
    Custom = "Custom"


@ops.input_position(db="top")
@op("RAG")
def rag(
    input,
    db,
    *,
    engine: RagEngine = RagEngine.Chroma,
    input_field="text",
    db_field="text",
    num_matches: int = 10,
    _ctx: one_by_one.Context,
):
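    """Returns the entries of the database most similar to the input, using either Chroma or a simple cosine-similarity search."""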
    global chroma_client
    if engine == RagEngine.Chroma:
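        # The collection is built from the database on the first item and reused for later items.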
        last = _ctx.last_result
        if last:
            collection = last["_collection"]
        else:
            collection_name = _ctx.node.id.replace(" ", "_")
            if chroma_client is None:
                import chromadb

                chroma_client = chromadb.Client()
            for c in chroma_client.list_collections():
                if c.name == collection_name:
                    chroma_client.delete_collection(name=collection_name)
            collection = chroma_client.create_collection(name=collection_name)
            collection.add(
                documents=[r[db_field] for r in db],
                ids=[str(i) for i in range(len(db))],
            )
        results = collection.query(
            query_texts=[input[input_field]],
            n_results=num_matches,
        )
        results = [db[int(r)] for r in results["ids"][0]]
        return {**input, "rag": results, "_collection": collection}
    if engine == RagEngine.Custom:
        query = input[input_field]
        embeddings = [embedding(input=[r[db_field]]) for r in db]
        q = embedding(input=[query])

        def cosine_similarity(a, b):
            return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))

        scores = [(i, cosine_similarity(q, e)) for i, e in enumerate(embeddings)]
        scores.sort(key=lambda x: -x[1])
        matches = [db[i][db_field] for i, _ in scores[:num_matches]]
        return {**input, "rag": matches}


@op("Run Python")
def run_python(input, *, template: str):
    """TODO: Implement."""
    return input