Format/lint changes.

Files changed:
- server/llm_ops.py (+154 -125)
- web/src/workspace/nodes/LynxKiteNode.tsx (+0 -3)
server/llm_ops.py
CHANGED
@@ -1,4 +1,5 @@
-
+"""For specifying an LLM agent logic flow."""
+
 from . import ops
 import chromadb
 import enum
@@ -14,177 +15,205 @@ embedding_client = openai.OpenAI(base_url="http://localhost:7997/")
 jinja = jinja2.Environment()
 chroma_client = chromadb.Client()
 LLM_CACHE = {}
-ENV =
+ENV = "LLM logic"
 one_by_one.register(ENV)
 op = ops.op_registration(ENV)
 
+
 def chat(*args, **kwargs):
-
-
-
-
-
+    key = json.dumps({"method": "chat", "args": args, "kwargs": kwargs})
+    if key not in LLM_CACHE:
+        completion = chat_client.chat.completions.create(*args, **kwargs)
+        LLM_CACHE[key] = [c.message.content for c in completion.choices]
+    return LLM_CACHE[key]
+
 
 def embedding(*args, **kwargs):
-
-
-
-
-
-
+    key = json.dumps({"method": "embedding", "args": args, "kwargs": kwargs})
+    if key not in LLM_CACHE:
+        res = embedding_client.embeddings.create(*args, **kwargs)
+        [data] = res.data
+        LLM_CACHE[key] = data.embedding
+    return LLM_CACHE[key]
+
 
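Note: both wrappers memoize on a json.dumps of their arguments, so re-running a workspace replays cached responses instead of hitting the model servers again. A minimal standalone sketch of the same pattern, where expensive_call is a hypothetical stand-in for the OpenAI client:

    import json

    LLM_CACHE = {}

    def expensive_call(**kwargs):  # hypothetical stand-in for the API call
        return {"echo": kwargs}

    def cached(**kwargs):
        # Keyword order affects the key; json.dumps(..., sort_keys=True)
        # would normalize it.
        key = json.dumps({"method": "chat", "kwargs": kwargs})
        if key not in LLM_CACHE:
            LLM_CACHE[key] = expensive_call(**kwargs)
        return LLM_CACHE[key]

    assert cached(model="m", prompt="hi") is cached(model="m", prompt="hi")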
 @op("Input CSV")
 def input_csv(*, filename: ops.PathStr, key: str):
-
+    return pd.read_csv(filename).rename(columns={key: "text"})
+
 
 @op("Input document")
 def input_document(*, filename: ops.PathStr):
-
-
+    with open(filename) as f:
+        return {"text": f.read()}
+
 
 @op("Input chat")
 def input_chat(*, chat: str):
-
+    return {"text": chat}
+
 
 @op("Split document")
-def split_document(input, *, delimiter: str =
-
-
-
+def split_document(input, *, delimiter: str = "\\n\\n"):
+    delimiter = delimiter.encode().decode("unicode_escape")
+    chunks = input["text"].split(delimiter)
+    return pd.DataFrame(chunks, columns=["text"])
+
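The default delimiter arrives from the UI as the literal characters backslash-n, not real newlines, which is why split_document decodes escape sequences before splitting. For example:

    # "\\n\\n" is a 4-character string until unicode_escape decodes it.
    d = "\\n\\n".encode().decode("unicode_escape")
    assert d == "\n\n"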
 
 @ops.input_position(input="top")
 @op("Build document graph")
 def build_document_graph(input):
-
+    return [{"source": i, "target": i + 1} for i in range(len(input) - 1)]
+
 
 @ops.input_position(nodes="top", edges="top")
 @op("Predict links")
 def predict_links(nodes, edges):
-
-
-
-
-
-
-
-
-
-
+    """A placeholder for a real algorithm. For now just adds 2-hop neighbors."""
+    edge_map = {}  # Source -> [Targets]
+    for edge in edges:
+        edge_map.setdefault(edge["source"], [])
+        edge_map[edge["source"]].append(edge["target"])
+    new_edges = []
+    for edge in edges:
+        for t in edge_map.get(edge["target"], []):
+            new_edges.append({"source": edge["source"], "target": t})
+    return edges + new_edges
+
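On a chain graph the 2-hop rule adds one shortcut per pair of consecutive edges. A quick standalone check of the logic above:

    edges = [{"source": 0, "target": 1}, {"source": 1, "target": 2}]
    edge_map = {}
    for e in edges:
        edge_map.setdefault(e["source"], []).append(e["target"])
    new_edges = [
        {"source": e["source"], "target": t}
        for e in edges
        for t in edge_map.get(e["target"], [])
    ]
    assert new_edges == [{"source": 0, "target": 2}]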
 
 @ops.input_position(nodes="top", edges="top")
 @op("Add neighbors")
 def add_neighbors(nodes, edges, item):
-
-
-
-
-
-
-
-
-
+    nodes = pd.DataFrame(nodes)
+    edges = pd.DataFrame(edges)
+    matches = item["rag"]
+    additional_matches = []
+    for m in matches:
+        node = nodes[nodes["text"] == m].index[0]
+        neighbors = edges[edges["source"] == node]["target"].to_list()
+        additional_matches.extend(nodes.loc[neighbors, "text"])
+    return {**item, "rag": matches + additional_matches}
+
 
 @op("Create prompt")
-def create_prompt(input, *, save_as=
-
-
-
-
+def create_prompt(input, *, save_as="prompt", template: ops.LongStr):
+    assert (
+        template
+    ), "Please specify the template. Refer to columns using the Jinja2 syntax."
+    t = jinja.from_string(template)
+    prompt = t.render(**input)
+    return {**input, save_as: prompt}
+
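The template is standard Jinja2; each column of the incoming item is available as a variable, and the rendered prompt is stored under save_as. A small example (the template text here is illustrative):

    import jinja2

    t = jinja2.Environment().from_string("Context: {{ rag }}\nQ: {{ text }}")
    print(t.render(text="What is LynxKite?", rag=["chunk 1", "chunk 2"]))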
 
 @op("Ask LLM")
 def ask_llm(input, *, model: str, accepted_regex: str = None, max_tokens: int = 100):
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    assert model, "Please specify the model."
+    assert "prompt" in input, "Please create the prompt first."
+    options = {}
+    if accepted_regex:
+        options["extra_body"] = {
+            "guided_regex": accepted_regex,
+        }
+    results = chat(
+        model=model,
+        max_tokens=max_tokens,
+        messages=[
+            {"role": "user", "content": input["prompt"]},
+        ],
+        **options,
+    )
+    return [{**input, "response": r} for r in results]
+
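guided_regex is not part of the standard OpenAI API; extra_body just forwards it to the server, presumably a vLLM-style backend with guided decoding enabled. A hedged usage sketch constraining the answer shape (the model name is hypothetical):

    results = chat(
        model="some-local-model",  # hypothetical
        max_tokens=5,
        messages=[{"role": "user", "content": "Rate this from 1 to 10."}],
        extra_body={"guided_regex": r"[0-9]+"},
    )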
 
 @op("View", view="table_view")
 def view(input, *, _ctx: one_by_one.Context):
-
-
-
-
-
-
-
-
-
-
-
-
-
+    v = _ctx.last_result
+    if v:
+        columns = v["dataframes"]["df"]["columns"]
+        v["dataframes"]["df"]["data"].append([input[c] for c in columns])
+    else:
+        columns = [str(c) for c in input.keys() if not str(c).startswith("_")]
+        v = {
+            "dataframes": {
+                "df": {
+                    "columns": columns,
+                    "data": [[input[c] for c in columns]],
+                }
+            }
+        }
+    return v
+
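The view accumulates rows across items: the first item defines the columns (skipping underscore-prefixed internals like _collection), and each later item appends one row to the shared last_result. The same logic, standalone:

    v = None
    for item in [{"text": "a", "_collection": None}, {"text": "b", "_collection": None}]:
        if v:
            cols = v["dataframes"]["df"]["columns"]
            v["dataframes"]["df"]["data"].append([item[c] for c in cols])
        else:
            cols = [c for c in item if not c.startswith("_")]
            v = {"dataframes": {"df": {"columns": cols, "data": [[item[c] for c in cols]]}}}
    assert v["dataframes"]["df"]["data"] == [["a"], ["b"]]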
 
 @ops.input_position(input="right")
 @ops.output_position(output="left")
 @op("Loop")
 def loop(input, *, max_iterations: int = 3, _ctx: one_by_one.Context):
-
-
-
-
-
+    """Data can flow back here max_iterations-1 times."""
+    key = f"iterations-{_ctx.node.id}"
+    input[key] = input.get(key, 0) + 1
+    if input[key] < max_iterations:
+        return input
+
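The loop counts iterations on the item itself, keyed by node ID, and falling off the end (returning None) breaks the cycle. A simplified copy showing the behavior with the default max_iterations=3 (node ID hypothetical):

    def loop(item, max_iterations=3, node_id="Loop_1"):  # simplified copy
        key = f"iterations-{node_id}"
        item[key] = item.get(key, 0) + 1
        if item[key] < max_iterations:
            return item

    item = {}
    assert loop(item) == {"iterations-Loop_1": 1}
    assert loop(item) == {"iterations-Loop_1": 2}
    assert loop(item) is None  # third pass stops the loop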
 
-@op(
+@op("Branch", outputs=["true", "false"])
 def branch(input, *, expression: str):
-
-
+    res = eval(expression, input)
+    return one_by_one.Output(output_handle=str(bool(res)).lower(), value=input)
+
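eval runs the expression with the item dict as its globals, so columns are referenced by bare name, and the truth value picks the output handle. (The usual eval caution applies: the expression can run arbitrary code.)

    item = {"response": "yes"}
    res = eval("response == 'yes'", item)
    assert str(bool(res)).lower() == "true"  # routed to the "true" handle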
 
 class RagEngine(enum.Enum):
-
-
+    Chroma = "Chroma"
+    Custom = "Custom"
+
 
 @ops.input_position(db="top")
-@op(
+@op("RAG")
 def rag(
-
-
-
-
-
-
-
-
-
-
-
-if
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    input,
+    db,
+    *,
+    engine: RagEngine = RagEngine.Chroma,
+    input_field="text",
+    db_field="text",
+    num_matches: int = 10,
+    _ctx: one_by_one.Context,
+):
+    if engine == RagEngine.Chroma:
+        last = _ctx.last_result
+        if last:
+            collection = last["_collection"]
+        else:
+            collection_name = _ctx.node.id.replace(" ", "_")
+            for c in chroma_client.list_collections():
+                if c.name == collection_name:
+                    chroma_client.delete_collection(name=collection_name)
+            collection = chroma_client.create_collection(name=collection_name)
+            collection.add(
+                documents=[r[db_field] for r in db],
+                ids=[str(i) for i in range(len(db))],
+            )
+        results = collection.query(
+            query_texts=[input[input_field]],
+            n_results=num_matches,
+        )
+        results = [db[int(r)] for r in results["ids"][0]]
+        return {**input, "rag": results, "_collection": collection}
+    if engine == RagEngine.Custom:
+        model = "google/gemma-2-2b-it"
+        chat = input[input_field]
+        embeddings = [embedding(input=[r[db_field]], model=model) for r in db]
+        q = embedding(input=[chat], model=model)
+
+        def cosine_similarity(a, b):
+            return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
+
+        scores = [(i, cosine_similarity(q, e)) for i, e in enumerate(embeddings)]
+        scores.sort(key=lambda x: -x[1])
+        matches = [db[i][db_field] for i, _ in scores[:num_matches]]
+        return {**input, "rag": matches}
+
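The custom engine is brute-force nearest neighbors on cosine similarity over the whole db. A standalone illustration with toy 2-d embeddings in place of real model output:

    import numpy as np

    def cosine_similarity(a, b):
        return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))

    db = ["cats", "dogs", "cars"]
    embeddings = [[1.0, 0.0], [0.9, 0.1], [0.0, 1.0]]  # toy vectors
    q = [1.0, 0.05]
    scores = [(i, cosine_similarity(q, e)) for i, e in enumerate(embeddings)]
    scores.sort(key=lambda x: -x[1])
    assert [db[i] for i, _ in scores[:2]] == ["cats", "dogs"]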
+
+@op("Run Python")
 def run_python(input, *, template: str):
-
-
+    """TODO: Implement."""
+    return input
web/src/workspace/nodes/LynxKiteNode.tsx
CHANGED

@@ -1,5 +1,3 @@
-import { useContext } from 'react';
-import { LynxKiteState } from '../LynxKiteState';
 import { useReactFlow, Handle, NodeResizeControl, Position } from '@xyflow/react';
 // @ts-ignore
 import ChevronDownRight from '~icons/tabler/chevron-down-right.jsx';
@@ -45,7 +43,6 @@ function getHandles(inputs: object, outputs: object) {
 export default function LynxKiteNode(props: LynxKiteNodeProps) {
   const reactFlow = useReactFlow();
   const data = props.data;
-  const state = useContext(LynxKiteState);
   const expanded = !data.collapsed;
   const handles = getHandles(data.meta?.inputs || {}, data.meta?.outputs || {});
   function asPx(n: number | undefined) {