Add RAG, batch inputs, caching.
- requirements.txt +2 -0
- server/llm_ops.py +118 -55
- server/ops.py +0 -3
- server/test_llm_ops.py +36 -14
- server/workspace.py +0 -1
requirements.txt
CHANGED
@@ -6,4 +6,6 @@ pandas
scipy
uvicorn[standard]
# For llm_ops
+chromadb
+Jinja2
openai
server/llm_ops.py
CHANGED
@@ -1,48 +1,50 @@
'''For specifying an LLM agent logic flow.'''
from . import ops
+import chromadb
+import fastapi.encoders
import inspect
+import jinja2
import json
import openai
import pandas as pd
import traceback
+import typing
from . import workspace

client = openai.OpenAI(base_url="http://localhost:11434/v1")
+jinja = jinja2.Environment()
+chroma_client = chromadb.Client()
+LLM_CACHE = {}
ENV = 'LLM logic'
op = ops.op_registration(ENV)

-class Context:
+class Context(ops.BaseConfig):
    '''Passed to operation functions as "_ctx" if they have such a parameter.'''
    node: workspace.WorkspaceNode
-    last_result = None
+    last_result: typing.Any = None

-class Output:
+class Output(ops.BaseConfig):
    '''Return this to send values to specific outputs of a node.'''
    output_handle: str
    value: dict

def chat(*args, **kwargs):
    key = json.dumps({'args': args, 'kwargs': kwargs})
+    if key not in LLM_CACHE:
        completion = client.chat.completions.create(*args, **kwargs)
+        LLM_CACHE[key] = [c.message.content for c in completion.choices]
+    return LLM_CACHE[key]

@op("Input")
def input(*, filename: ops.PathStr, key: str):
    return pd.read_csv(filename).rename(columns={key: 'text'})

@op("Create prompt")
-def create_prompt(input, *, template: ops.LongStr):
+def create_prompt(input, *, save_as='prompt', template: ops.LongStr):
+    assert template, 'Please specify the template. Refer to columns using the Jinja2 syntax.'
+    t = jinja.from_string(template)
+    prompt = t.render(**input)
+    return {**input, save_as: prompt}

@op("Ask LLM")
def ask_llm(input, *, model: str, accepted_regex: str = None, max_tokens: int = 100):
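The `chat` helper now memoizes completions in `LLM_CACHE`, keyed on the JSON-serialized call arguments, and `Create prompt` renders its template with Jinja2 against the incoming row. A minimal standalone sketch of both ideas (the row, template, and model name below are invented):

    import json
    import jinja2

    # Cache key: the same scheme chat() uses above.
    key = json.dumps({'args': (), 'kwargs': {'model': 'llama3', 'max_tokens': 100}})

    # Prompt rendering: the same scheme create_prompt() uses above.
    row = {'text': 'What is 2 + 2?', 'rag': ['1 + 1 = 2', '2 + 3 = 5']}
    template = 'Context: {{ rag }}\nQuestion: {{ text }}'
    prompt = jinja2.Environment().from_string(template).render(**row)
    # The op returns {**row, 'prompt': prompt}, so downstream nodes keep the original columns.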
@@ -74,7 +76,7 @@ def view(input, *, _ctx: Context):
    v = {
        'dataframes': { 'df': {
            'columns': columns,
-            'data': [input[c] for c in columns],
+            'data': [[input[c] for c in columns]],
        }}
    }
    return v
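The `View` fix above wraps the row in an outer list, so `data` is a list of rows rather than one flat row. For a single input row the payload is shaped roughly like this (values invented):

    columns = ['text', 'prompt']
    input = {'text': 'What is 2 + 2?', 'prompt': 'Q: What is 2 + 2?'}
    v = {'dataframes': {'df': {
        'columns': columns,
        'data': [[input[c] for c in columns]],  # one row
    }}}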
@@ -92,12 +94,30 @@ def loop(input, *, max_iterations: int = 3, _ctx: Context):
@op('Branch', outputs=['true', 'false'])
def branch(input, *, expression: str):
    res = eval(expression, input)
-    return Output(str(bool(res)).lower(), input)
+    return Output(output_handle=str(bool(res)).lower(), value=input)

@ops.input_position(db="top")
@op('RAG')
+def rag(input, db, *, input_field='text', db_field='text', num_matches: int=10, _ctx: Context):
+    last = _ctx.last_result
+    if last:
+        collection = last['_collection']
+    else:
+        collection_name = _ctx.node.id.replace(' ', '_')
+        for c in chroma_client.list_collections():
+            if c.name == collection_name:
+                chroma_client.delete_collection(name=collection_name)
+        collection = chroma_client.create_collection(name=collection_name)
+        collection.add(
+            documents=[r[db_field] for r in db],
+            ids=[str(i) for i in range(len(db))],
+        )
+    results = collection.query(
+        query_texts=[input[input_field]],
+        n_results=num_matches,
+    )
+    results = [db[int(r)] for r in results['ids'][0]]
+    return {**input, 'rag': results, '_collection': collection}

@op('Run Python')
def run_python(input, *, template: str):
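The new `RAG` op loads the `db` rows into an in-memory ChromaDB collection (built on the first task, then reused via `_ctx.last_result`) and looks up the closest matches for each input row. A minimal standalone sketch of the ChromaDB calls it relies on (collection name and documents invented):

    import chromadb

    chroma_client = chromadb.Client()  # in-memory client, as in llm_ops
    collection = chroma_client.create_collection(name='rag_demo')
    collection.add(
        documents=['1 + 1 = 2', '2 + 3 = 5', 'the sky is blue'],
        ids=['0', '1', '2'],
    )
    results = collection.query(query_texts=['What is 2 + 2?'], n_results=2)
    matches = results['ids'][0]  # ids for the first query text; the op maps these back to db rows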
@@ -107,16 +127,16 @@ def run_python(input, *, template: str):
        p = p.replace(k.upper(), str(v))
    return p

+EXECUTOR_OUTPUT_CACHE = {}

@ops.register_executor(ENV)
def execute(ws):
    catalog = ops.CATALOGS[ENV]
    nodes = {n.id: n for n in ws.nodes}
-    contexts = {n.id: Context(n) for n in ws.nodes}
+    contexts = {n.id: Context(node=n) for n in ws.nodes}
    edges = {n.id: [] for n in ws.nodes}
    for e in ws.edges:
+        edges[e.source].append(e)
    tasks = {}
    NO_INPUT = object() # Marker for initial tasks.
    for node in ws.nodes:
@@ -125,39 +145,54 @@ def execute(ws):
        # Start tasks for nodes that have no inputs.
        if not op.inputs:
            tasks[node.id] = [NO_INPUT]
+    batch_inputs = {}
    # Run the rest until we run out of tasks.
+    for stage in get_stages(ws):
+        next_stage = {}
+        while tasks:
+            n, ts = tasks.popitem()
+            if n not in stage:
+                next_stage.setdefault(n, []).extend(ts)
+                continue
+            node = nodes[n]
+            data = node.data
+            op = catalog[data.title]
+            params = {**data.params}
+            if has_ctx(op):
+                params['_ctx'] = contexts[node.id]
+            results = []
+            for task in ts:
+                try:
+                    inputs = [
+                        batch_inputs[(n, i.name)] if i.position == 'top' else task
+                        for i in op.inputs.values()]
+                    key = json.dumps(fastapi.encoders.jsonable_encoder((inputs, params)))
+                    if key not in EXECUTOR_OUTPUT_CACHE:
+                        EXECUTOR_OUTPUT_CACHE[key] = op.func(*inputs, **params)
+                    result = EXECUTOR_OUTPUT_CACHE[key]
+                except Exception as e:
+                    traceback.print_exc()
+                    data.error = str(e)
+                    break
+                contexts[node.id].last_result = result
+                # Returned lists and DataFrames are considered multiple tasks.
+                if isinstance(result, pd.DataFrame):
+                    result = df_to_list(result)
+                elif not isinstance(result, list):
+                    result = [result]
+                results.extend(result)
+            else: # Finished all tasks without errors.
+                if op.type == 'visualization' or op.type == 'table_view':
+                    data.display = results[0]
+                for edge in edges[node.id]:
+                    t = nodes[edge.target]
+                    op = catalog[t.data.title]
+                    i = op.inputs[edge.targetHandle]
+                    if i.position == 'top':
+                        batch_inputs.setdefault((edge.target, edge.targetHandle), []).extend(results)
+                    else:
+                        tasks.setdefault(edge.target, []).extend(results)
+        tasks = next_stage

def df_to_list(df):
    return [dict(zip(df.columns, row)) for row in df.values]
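The executor now memoizes node outputs in `EXECUTOR_OUTPUT_CACHE`, keyed on the JSON-encoded `(inputs, params)` pair; `fastapi.encoders.jsonable_encoder` makes the Pydantic values serializable first. A rough sketch of the keying (values invented; the dict literal stands in for `op.func(*inputs, **params)`):

    import json
    import fastapi.encoders

    EXECUTOR_OUTPUT_CACHE = {}
    inputs = [{'text': 'What is 2 + 2?'}]
    params = {'template': 'Q: {{ text }}'}
    key = json.dumps(fastapi.encoders.jsonable_encoder((inputs, params)))
    if key not in EXECUTOR_OUTPUT_CACHE:
        EXECUTOR_OUTPUT_CACHE[key] = {'prompt': 'Q: What is 2 + 2?'}  # stand-in for op.func(*inputs, **params)
    result = EXECUTOR_OUTPUT_CACHE[key]

Note the `for`/`else` in the task loop: the `else` branch runs only when no task hit `break`, so a failed task records `data.error` and stops results from propagating to downstream nodes.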
@@ -165,3 +200,31 @@ def df_to_list(df):
def has_ctx(op):
    sig = inspect.signature(op.func)
    return '_ctx' in sig.parameters
+
+def get_stages(ws):
+    '''Inputs on top are batch inputs. We decompose the graph into a DAG of components along these edges.'''
+    catalog = ops.CATALOGS[ENV]
+    nodes = {n.id: n for n in ws.nodes}
+    batch_inputs = {}
+    inputs = {}
+    for edge in ws.edges:
+        inputs.setdefault(edge.target, []).append(edge.source)
+        node = nodes[edge.target]
+        op = catalog[node.data.title]
+        i = op.inputs[edge.targetHandle]
+        if i.position == 'top':
+            batch_inputs.setdefault(edge.target, []).append(edge.source)
+    stages = []
+    for bt, bss in batch_inputs.items():
+        upstream = set(bss)
+        new = set(bss)
+        while new:
+            n = new.pop()
+            for i in inputs.get(n, []):
+                if i not in upstream:
+                    upstream.add(i)
+                    new.add(i)
+        stages.append(upstream)
+    stages.sort(key=lambda s: len(s))
+    stages.append(set(nodes))
+    return stages
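`get_stages` computes, for every node with a `top` (batch) input, the set of nodes upstream of that input's sources; each closure becomes a stage, stages are ordered by size, and a final stage containing every node is appended. The same closure idea on a toy graph (node names invented):

    # target -> sources; 'rag' receives its 'db' input on top, fed by 'in1'.
    inputs = {'rag': ['in1', 'in2'], 'prompt': ['rag']}
    batch_inputs = {'rag': ['in1']}

    stages = []
    for target, sources in batch_inputs.items():
        upstream, new = set(sources), set(sources)
        while new:
            n = new.pop()
            for src in inputs.get(n, []):
                if src not in upstream:
                    upstream.add(src)
                    new.add(src)
        stages.append(upstream)
    stages.sort(key=len)
    stages.append({'in1', 'in2', 'rag', 'prompt'})  # final stage: every node
    # stages == [{'in1'}, {'in1', 'in2', 'rag', 'prompt'}]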
server/ops.py
CHANGED
@@ -1,11 +1,8 @@
'''API for implementing LynxKite operations.'''
from __future__ import annotations
-import dataclasses
import enum
import functools
import inspect
-import networkx as nx
-import pandas as pd
import pydantic
import typing
from typing_extensions import Annotated
server/test_llm_ops.py
CHANGED
@@ -2,27 +2,49 @@ import unittest
from . import llm_ops
from . import workspace

+def make_node(id, op, type='basic', **params):
+    return workspace.WorkspaceNode(
+        id=id,
+        type=type,
+        position=workspace.Position(x=0, y=0),
+        data=workspace.WorkspaceNodeData(title=op, params=params),
+    )
+def make_input(id):
+    return make_node(
+        id, 'Input',
+        filename='/Users/danieldarabos/Downloads/aimo-train.csv',
+        key='problem')
+def make_edge(source, target, targetHandle='input'):
+    return workspace.WorkspaceEdge(
+        id=f'{source}-{target}', source=source, target=target, sourceHandle='', targetHandle=targetHandle)
+
class LLMOpsTest(unittest.TestCase):
    def testExecute(self):
        ws = workspace.Workspace(env='LLM logic', nodes=[
-            workspace.WorkspaceNode(
-                id='1',
-                type='table_view',
-                position=workspace.Position(x=0, y=0),
-                data=workspace.WorkspaceNodeData(title='View', params={})),
+            make_node(
+                '0', 'Input',
+                filename='/Users/danieldarabos/Downloads/aimo-train.csv',
+                key='problem'),
+            make_node(
+                '1', 'View', type='table_view'),
        ], edges=[
+            make_edge('0', '1')
        ])
        llm_ops.execute(ws)
        self.assertEqual('', ws.nodes[1].data.display)

+    def testStages(self):
+        ws = workspace.Workspace(env='LLM logic', nodes=[
+            make_input('in1'), make_input('in2'), make_input('in3'),
+            make_node('rag1', 'RAG'), make_node('rag2', 'RAG'),
+            make_node('p1', 'Create prompt'), make_node('p2', 'Create prompt'),
+        ], edges=[
+            make_edge('in1', 'rag1', 'db'), make_edge('in2', 'rag1'),
+            make_edge('rag1', 'p1'), make_edge('p1', 'rag2', 'db'),
+            make_edge('in3', 'p2'), make_edge('p3', 'rag2'),
+        ])
+        stages = llm_ops.get_stages(ws)
+        self.assertEqual('', stages)
+
if __name__ == '__main__':
    unittest.main()
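Assuming `server` is an importable package at the repository root (as the relative imports suggest), the new tests can presumably be run with the standard unittest runner:

    # From the repository root:
    #   python -m unittest server.test_llm_ops
    import unittest
    unittest.main(module='server.test_llm_ops', exit=False)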
server/workspace.py
CHANGED
@@ -4,7 +4,6 @@ import dataclasses
import os
import pydantic
import tempfile
-import traceback
from . import ops

class BaseConfig(pydantic.BaseModel):