Spaces:
Running
Running
Add RAG, batch inputs, caching.
Browse files- requirements.txt +2 -0
- server/llm_ops.py +118 -55
- server/ops.py +0 -3
- server/test_llm_ops.py +36 -14
- server/workspace.py +0 -1
requirements.txt
CHANGED
|
@@ -6,4 +6,6 @@ pandas
|
|
| 6 |
scipy
|
| 7 |
uvicorn[standard]
|
| 8 |
# For llm_ops
|
|
|
|
|
|
|
| 9 |
openai
|
|
|
|
| 6 |
scipy
|
| 7 |
uvicorn[standard]
|
| 8 |
# For llm_ops
|
| 9 |
+
chromadb
|
| 10 |
+
Jinja2
|
| 11 |
openai
|
server/llm_ops.py
CHANGED
|
@@ -1,48 +1,50 @@
|
|
| 1 |
'''For specifying an LLM agent logic flow.'''
|
| 2 |
from . import ops
|
| 3 |
-
import
|
|
|
|
| 4 |
import inspect
|
|
|
|
| 5 |
import json
|
| 6 |
import openai
|
| 7 |
import pandas as pd
|
| 8 |
import traceback
|
|
|
|
| 9 |
from . import workspace
|
| 10 |
|
| 11 |
client = openai.OpenAI(base_url="http://localhost:11434/v1")
|
| 12 |
-
|
|
|
|
|
|
|
| 13 |
ENV = 'LLM logic'
|
| 14 |
op = ops.op_registration(ENV)
|
| 15 |
|
| 16 |
-
|
| 17 |
-
class Context:
|
| 18 |
'''Passed to operation functions as "_ctx" if they have such a parameter.'''
|
| 19 |
node: workspace.WorkspaceNode
|
| 20 |
-
last_result = None
|
| 21 |
|
| 22 |
-
|
| 23 |
-
class Output:
|
| 24 |
'''Return this to send values to specific outputs of a node.'''
|
| 25 |
output_handle: str
|
| 26 |
value: dict
|
| 27 |
|
| 28 |
def chat(*args, **kwargs):
|
| 29 |
key = json.dumps({'args': args, 'kwargs': kwargs})
|
| 30 |
-
if key not in
|
| 31 |
completion = client.chat.completions.create(*args, **kwargs)
|
| 32 |
-
|
| 33 |
-
return
|
| 34 |
|
| 35 |
@op("Input")
|
| 36 |
def input(*, filename: ops.PathStr, key: str):
|
| 37 |
return pd.read_csv(filename).rename(columns={key: 'text'})
|
| 38 |
|
| 39 |
@op("Create prompt")
|
| 40 |
-
def create_prompt(input, *, template: ops.LongStr):
|
| 41 |
-
assert template, 'Please specify the template. Refer to columns using
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
return p
|
| 46 |
|
| 47 |
@op("Ask LLM")
|
| 48 |
def ask_llm(input, *, model: str, accepted_regex: str = None, max_tokens: int = 100):
|
|
@@ -74,7 +76,7 @@ def view(input, *, _ctx: Context):
|
|
| 74 |
v = {
|
| 75 |
'dataframes': { 'df': {
|
| 76 |
'columns': columns,
|
| 77 |
-
'data': [input[c] for c in columns],
|
| 78 |
}}
|
| 79 |
}
|
| 80 |
return v
|
|
@@ -92,12 +94,30 @@ def loop(input, *, max_iterations: int = 3, _ctx: Context):
|
|
| 92 |
@op('Branch', outputs=['true', 'false'])
|
| 93 |
def branch(input, *, expression: str):
|
| 94 |
res = eval(expression, input)
|
| 95 |
-
return Output(str(bool(res)).lower(), input)
|
| 96 |
|
| 97 |
@ops.input_position(db="top")
|
| 98 |
@op('RAG')
|
| 99 |
-
def rag(input, db, *,
|
| 100 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 101 |
|
| 102 |
@op('Run Python')
|
| 103 |
def run_python(input, *, template: str):
|
|
@@ -107,16 +127,16 @@ def run_python(input, *, template: str):
|
|
| 107 |
p = p.replace(k.upper(), str(v))
|
| 108 |
return p
|
| 109 |
|
| 110 |
-
|
| 111 |
|
| 112 |
@ops.register_executor(ENV)
|
| 113 |
def execute(ws):
|
| 114 |
catalog = ops.CATALOGS[ENV]
|
| 115 |
nodes = {n.id: n for n in ws.nodes}
|
| 116 |
-
contexts = {n.id: Context(n) for n in ws.nodes}
|
| 117 |
edges = {n.id: [] for n in ws.nodes}
|
| 118 |
for e in ws.edges:
|
| 119 |
-
edges[e.source].append(e
|
| 120 |
tasks = {}
|
| 121 |
NO_INPUT = object() # Marker for initial tasks.
|
| 122 |
for node in ws.nodes:
|
|
@@ -125,39 +145,54 @@ def execute(ws):
|
|
| 125 |
# Start tasks for nodes that have no inputs.
|
| 126 |
if not op.inputs:
|
| 127 |
tasks[node.id] = [NO_INPUT]
|
|
|
|
| 128 |
# Run the rest until we run out of tasks.
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 161 |
|
| 162 |
def df_to_list(df):
|
| 163 |
return [dict(zip(df.columns, row)) for row in df.values]
|
|
@@ -165,3 +200,31 @@ def df_to_list(df):
|
|
| 165 |
def has_ctx(op):
|
| 166 |
sig = inspect.signature(op.func)
|
| 167 |
return '_ctx' in sig.parameters
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
'''For specifying an LLM agent logic flow.'''
|
| 2 |
from . import ops
|
| 3 |
+
import chromadb
|
| 4 |
+
import fastapi.encoders
|
| 5 |
import inspect
|
| 6 |
+
import jinja2
|
| 7 |
import json
|
| 8 |
import openai
|
| 9 |
import pandas as pd
|
| 10 |
import traceback
|
| 11 |
+
import typing
|
| 12 |
from . import workspace
|
| 13 |
|
| 14 |
client = openai.OpenAI(base_url="http://localhost:11434/v1")
|
| 15 |
+
jinja = jinja2.Environment()
|
| 16 |
+
chroma_client = chromadb.Client()
|
| 17 |
+
LLM_CACHE = {}
|
| 18 |
ENV = 'LLM logic'
|
| 19 |
op = ops.op_registration(ENV)
|
| 20 |
|
| 21 |
+
class Context(ops.BaseConfig):
|
|
|
|
| 22 |
'''Passed to operation functions as "_ctx" if they have such a parameter.'''
|
| 23 |
node: workspace.WorkspaceNode
|
| 24 |
+
last_result: typing.Any = None
|
| 25 |
|
| 26 |
+
class Output(ops.BaseConfig):
|
|
|
|
| 27 |
'''Return this to send values to specific outputs of a node.'''
|
| 28 |
output_handle: str
|
| 29 |
value: dict
|
| 30 |
|
| 31 |
def chat(*args, **kwargs):
|
| 32 |
key = json.dumps({'args': args, 'kwargs': kwargs})
|
| 33 |
+
if key not in LLM_CACHE:
|
| 34 |
completion = client.chat.completions.create(*args, **kwargs)
|
| 35 |
+
LLM_CACHE[key] = [c.message.content for c in completion.choices]
|
| 36 |
+
return LLM_CACHE[key]
|
| 37 |
|
| 38 |
@op("Input")
|
| 39 |
def input(*, filename: ops.PathStr, key: str):
|
| 40 |
return pd.read_csv(filename).rename(columns={key: 'text'})
|
| 41 |
|
| 42 |
@op("Create prompt")
|
| 43 |
+
def create_prompt(input, *, save_as='prompt', template: ops.LongStr):
|
| 44 |
+
assert template, 'Please specify the template. Refer to columns using the Jinja2 syntax.'
|
| 45 |
+
t = jinja.from_string(template)
|
| 46 |
+
prompt = t.render(**input)
|
| 47 |
+
return {**input, save_as: prompt}
|
|
|
|
| 48 |
|
| 49 |
@op("Ask LLM")
|
| 50 |
def ask_llm(input, *, model: str, accepted_regex: str = None, max_tokens: int = 100):
|
|
|
|
| 76 |
v = {
|
| 77 |
'dataframes': { 'df': {
|
| 78 |
'columns': columns,
|
| 79 |
+
'data': [[input[c] for c in columns]],
|
| 80 |
}}
|
| 81 |
}
|
| 82 |
return v
|
|
|
|
| 94 |
@op('Branch', outputs=['true', 'false'])
|
| 95 |
def branch(input, *, expression: str):
|
| 96 |
res = eval(expression, input)
|
| 97 |
+
return Output(output_handle=str(bool(res)).lower(), value=input)
|
| 98 |
|
| 99 |
@ops.input_position(db="top")
|
| 100 |
@op('RAG')
|
| 101 |
+
def rag(input, db, *, input_field='text', db_field='text', num_matches: int=10, _ctx: Context):
|
| 102 |
+
last = _ctx.last_result
|
| 103 |
+
if last:
|
| 104 |
+
collection = last['_collection']
|
| 105 |
+
else:
|
| 106 |
+
collection_name = _ctx.node.id.replace(' ', '_')
|
| 107 |
+
for c in chroma_client.list_collections():
|
| 108 |
+
if c.name == collection_name:
|
| 109 |
+
chroma_client.delete_collection(name=collection_name)
|
| 110 |
+
collection = chroma_client.create_collection(name=collection_name)
|
| 111 |
+
collection.add(
|
| 112 |
+
documents=[r[db_field] for r in db],
|
| 113 |
+
ids=[str(i) for i in range(len(db))],
|
| 114 |
+
)
|
| 115 |
+
results = collection.query(
|
| 116 |
+
query_texts=[input[input_field]],
|
| 117 |
+
n_results=num_matches,
|
| 118 |
+
)
|
| 119 |
+
results = [db[int(r)] for r in results['ids'][0]]
|
| 120 |
+
return {**input, 'rag': results, '_collection': collection}
|
| 121 |
|
| 122 |
@op('Run Python')
|
| 123 |
def run_python(input, *, template: str):
|
|
|
|
| 127 |
p = p.replace(k.upper(), str(v))
|
| 128 |
return p
|
| 129 |
|
| 130 |
+
EXECUTOR_OUTPUT_CACHE = {}
|
| 131 |
|
| 132 |
@ops.register_executor(ENV)
|
| 133 |
def execute(ws):
|
| 134 |
catalog = ops.CATALOGS[ENV]
|
| 135 |
nodes = {n.id: n for n in ws.nodes}
|
| 136 |
+
contexts = {n.id: Context(node=n) for n in ws.nodes}
|
| 137 |
edges = {n.id: [] for n in ws.nodes}
|
| 138 |
for e in ws.edges:
|
| 139 |
+
edges[e.source].append(e)
|
| 140 |
tasks = {}
|
| 141 |
NO_INPUT = object() # Marker for initial tasks.
|
| 142 |
for node in ws.nodes:
|
|
|
|
| 145 |
# Start tasks for nodes that have no inputs.
|
| 146 |
if not op.inputs:
|
| 147 |
tasks[node.id] = [NO_INPUT]
|
| 148 |
+
batch_inputs = {}
|
| 149 |
# Run the rest until we run out of tasks.
|
| 150 |
+
for stage in get_stages(ws):
|
| 151 |
+
next_stage = {}
|
| 152 |
+
while tasks:
|
| 153 |
+
n, ts = tasks.popitem()
|
| 154 |
+
if n not in stage:
|
| 155 |
+
next_stage.setdefault(n, []).extend(ts)
|
| 156 |
+
continue
|
| 157 |
+
node = nodes[n]
|
| 158 |
+
data = node.data
|
| 159 |
+
op = catalog[data.title]
|
| 160 |
+
params = {**data.params}
|
| 161 |
+
if has_ctx(op):
|
| 162 |
+
params['_ctx'] = contexts[node.id]
|
| 163 |
+
results = []
|
| 164 |
+
for task in ts:
|
| 165 |
+
try:
|
| 166 |
+
inputs = [
|
| 167 |
+
batch_inputs[(n, i.name)] if i.position == 'top' else task
|
| 168 |
+
for i in op.inputs.values()]
|
| 169 |
+
key = json.dumps(fastapi.encoders.jsonable_encoder((inputs, params)))
|
| 170 |
+
if key not in EXECUTOR_OUTPUT_CACHE:
|
| 171 |
+
EXECUTOR_OUTPUT_CACHE[key] = op.func(*inputs, **params)
|
| 172 |
+
result = EXECUTOR_OUTPUT_CACHE[key]
|
| 173 |
+
except Exception as e:
|
| 174 |
+
traceback.print_exc()
|
| 175 |
+
data.error = str(e)
|
| 176 |
+
break
|
| 177 |
+
contexts[node.id].last_result = result
|
| 178 |
+
# Returned lists and DataFrames are considered multiple tasks.
|
| 179 |
+
if isinstance(result, pd.DataFrame):
|
| 180 |
+
result = df_to_list(result)
|
| 181 |
+
elif not isinstance(result, list):
|
| 182 |
+
result = [result]
|
| 183 |
+
results.extend(result)
|
| 184 |
+
else: # Finished all tasks without errors.
|
| 185 |
+
if op.type == 'visualization' or op.type == 'table_view':
|
| 186 |
+
data.display = results[0]
|
| 187 |
+
for edge in edges[node.id]:
|
| 188 |
+
t = nodes[edge.target]
|
| 189 |
+
op = catalog[t.data.title]
|
| 190 |
+
i = op.inputs[edge.targetHandle]
|
| 191 |
+
if i.position == 'top':
|
| 192 |
+
batch_inputs.setdefault((edge.target, edge.targetHandle), []).extend(results)
|
| 193 |
+
else:
|
| 194 |
+
tasks.setdefault(edge.target, []).extend(results)
|
| 195 |
+
tasks = next_stage
|
| 196 |
|
| 197 |
def df_to_list(df):
|
| 198 |
return [dict(zip(df.columns, row)) for row in df.values]
|
|
|
|
| 200 |
def has_ctx(op):
|
| 201 |
sig = inspect.signature(op.func)
|
| 202 |
return '_ctx' in sig.parameters
|
| 203 |
+
|
| 204 |
+
def get_stages(ws):
|
| 205 |
+
'''Inputs on top are batch inputs. We decompose the graph into a DAG of components along these edges.'''
|
| 206 |
+
catalog = ops.CATALOGS[ENV]
|
| 207 |
+
nodes = {n.id: n for n in ws.nodes}
|
| 208 |
+
batch_inputs = {}
|
| 209 |
+
inputs = {}
|
| 210 |
+
for edge in ws.edges:
|
| 211 |
+
inputs.setdefault(edge.target, []).append(edge.source)
|
| 212 |
+
node = nodes[edge.target]
|
| 213 |
+
op = catalog[node.data.title]
|
| 214 |
+
i = op.inputs[edge.targetHandle]
|
| 215 |
+
if i.position == 'top':
|
| 216 |
+
batch_inputs.setdefault(edge.target, []).append(edge.source)
|
| 217 |
+
stages = []
|
| 218 |
+
for bt, bss in batch_inputs.items():
|
| 219 |
+
upstream = set(bss)
|
| 220 |
+
new = set(bss)
|
| 221 |
+
while new:
|
| 222 |
+
n = new.pop()
|
| 223 |
+
for i in inputs.get(n, []):
|
| 224 |
+
if i not in upstream:
|
| 225 |
+
upstream.add(i)
|
| 226 |
+
new.add(i)
|
| 227 |
+
stages.append(upstream)
|
| 228 |
+
stages.sort(key=lambda s: len(s))
|
| 229 |
+
stages.append(set(nodes))
|
| 230 |
+
return stages
|
server/ops.py
CHANGED
|
@@ -1,11 +1,8 @@
|
|
| 1 |
'''API for implementing LynxKite operations.'''
|
| 2 |
from __future__ import annotations
|
| 3 |
-
import dataclasses
|
| 4 |
import enum
|
| 5 |
import functools
|
| 6 |
import inspect
|
| 7 |
-
import networkx as nx
|
| 8 |
-
import pandas as pd
|
| 9 |
import pydantic
|
| 10 |
import typing
|
| 11 |
from typing_extensions import Annotated
|
|
|
|
| 1 |
'''API for implementing LynxKite operations.'''
|
| 2 |
from __future__ import annotations
|
|
|
|
| 3 |
import enum
|
| 4 |
import functools
|
| 5 |
import inspect
|
|
|
|
|
|
|
| 6 |
import pydantic
|
| 7 |
import typing
|
| 8 |
from typing_extensions import Annotated
|
server/test_llm_ops.py
CHANGED
|
@@ -2,27 +2,49 @@ import unittest
|
|
| 2 |
from . import llm_ops
|
| 3 |
from . import workspace
|
| 4 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
class LLMOpsTest(unittest.TestCase):
|
| 6 |
def testExecute(self):
|
| 7 |
ws = workspace.Workspace(env='LLM logic', nodes=[
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
'key': 'problem',
|
| 15 |
-
})),
|
| 16 |
-
workspace.WorkspaceNode(
|
| 17 |
-
id='1',
|
| 18 |
-
type='table_view',
|
| 19 |
-
position=workspace.Position(x=0, y=0),
|
| 20 |
-
data=workspace.WorkspaceNodeData(title='View', params={})),
|
| 21 |
], edges=[
|
| 22 |
-
|
| 23 |
])
|
| 24 |
llm_ops.execute(ws)
|
| 25 |
self.assertEqual('', ws.nodes[1].data.display)
|
| 26 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
if __name__ == '__main__':
|
| 28 |
unittest.main()
|
|
|
|
| 2 |
from . import llm_ops
|
| 3 |
from . import workspace
|
| 4 |
|
| 5 |
+
def make_node(id, op, type='basic', **params):
|
| 6 |
+
return workspace.WorkspaceNode(
|
| 7 |
+
id=id,
|
| 8 |
+
type=type,
|
| 9 |
+
position=workspace.Position(x=0, y=0),
|
| 10 |
+
data=workspace.WorkspaceNodeData(title=op, params=params),
|
| 11 |
+
)
|
| 12 |
+
def make_input(id):
|
| 13 |
+
return make_node(
|
| 14 |
+
id, 'Input',
|
| 15 |
+
filename='/Users/danieldarabos/Downloads/aimo-train.csv',
|
| 16 |
+
key='problem')
|
| 17 |
+
def make_edge(source, target, targetHandle='input'):
|
| 18 |
+
return workspace.WorkspaceEdge(
|
| 19 |
+
id=f'{source}-{target}', source=source, target=target, sourceHandle='', targetHandle=targetHandle)
|
| 20 |
+
|
| 21 |
class LLMOpsTest(unittest.TestCase):
|
| 22 |
def testExecute(self):
|
| 23 |
ws = workspace.Workspace(env='LLM logic', nodes=[
|
| 24 |
+
make_node(
|
| 25 |
+
'0', 'Input',
|
| 26 |
+
filename='/Users/danieldarabos/Downloads/aimo-train.csv',
|
| 27 |
+
key='problem'),
|
| 28 |
+
make_node(
|
| 29 |
+
'1', 'View', type='table_view'),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
], edges=[
|
| 31 |
+
make_edge('0', '1')
|
| 32 |
])
|
| 33 |
llm_ops.execute(ws)
|
| 34 |
self.assertEqual('', ws.nodes[1].data.display)
|
| 35 |
|
| 36 |
+
def testStages(self):
|
| 37 |
+
ws = workspace.Workspace(env='LLM logic', nodes=[
|
| 38 |
+
make_input('in1'), make_input('in2'), make_input('in3'),
|
| 39 |
+
make_node('rag1', 'RAG'), make_node('rag2', 'RAG'),
|
| 40 |
+
make_node('p1', 'Create prompt'), make_node('p2', 'Create prompt'),
|
| 41 |
+
], edges=[
|
| 42 |
+
make_edge('in1', 'rag1', 'db'), make_edge('in2', 'rag1'),
|
| 43 |
+
make_edge('rag1', 'p1'), make_edge('p1', 'rag2', 'db'),
|
| 44 |
+
make_edge('in3', 'p2'), make_edge('p3', 'rag2'),
|
| 45 |
+
])
|
| 46 |
+
stages = llm_ops.get_stages(ws)
|
| 47 |
+
self.assertEqual('', stages)
|
| 48 |
+
|
| 49 |
if __name__ == '__main__':
|
| 50 |
unittest.main()
|
server/workspace.py
CHANGED
|
@@ -4,7 +4,6 @@ import dataclasses
|
|
| 4 |
import os
|
| 5 |
import pydantic
|
| 6 |
import tempfile
|
| 7 |
-
import traceback
|
| 8 |
from . import ops
|
| 9 |
|
| 10 |
class BaseConfig(pydantic.BaseModel):
|
|
|
|
| 4 |
import os
|
| 5 |
import pydantic
|
| 6 |
import tempfile
|
|
|
|
| 7 |
from . import ops
|
| 8 |
|
| 9 |
class BaseConfig(pydantic.BaseModel):
|