darabos committed
Commit 6988728 · 1 Parent(s): a07e9cb

Add RAG, batch inputs, caching.

requirements.txt CHANGED
@@ -6,4 +6,6 @@ pandas
 scipy
 uvicorn[standard]
 # For llm_ops
+chromadb
+Jinja2
 openai
server/llm_ops.py CHANGED
@@ -1,48 +1,50 @@
 '''For specifying an LLM agent logic flow.'''
 from . import ops
-import dataclasses
+import chromadb
+import fastapi.encoders
 import inspect
+import jinja2
 import json
 import openai
 import pandas as pd
 import traceback
+import typing
 from . import workspace
 
 client = openai.OpenAI(base_url="http://localhost:11434/v1")
-CACHE = {}
+jinja = jinja2.Environment()
+chroma_client = chromadb.Client()
+LLM_CACHE = {}
 ENV = 'LLM logic'
 op = ops.op_registration(ENV)
 
-@dataclasses.dataclass
-class Context:
+class Context(ops.BaseConfig):
   '''Passed to operation functions as "_ctx" if they have such a parameter.'''
   node: workspace.WorkspaceNode
-  last_result = None
+  last_result: typing.Any = None
 
-@dataclasses.dataclass
-class Output:
+class Output(ops.BaseConfig):
   '''Return this to send values to specific outputs of a node.'''
   output_handle: str
   value: dict
 
 def chat(*args, **kwargs):
   key = json.dumps({'args': args, 'kwargs': kwargs})
-  if key not in CACHE:
+  if key not in LLM_CACHE:
     completion = client.chat.completions.create(*args, **kwargs)
-    CACHE[key] = [c.message.content for c in completion.choices]
-  return CACHE[key]
+    LLM_CACHE[key] = [c.message.content for c in completion.choices]
+  return LLM_CACHE[key]
 
 @op("Input")
 def input(*, filename: ops.PathStr, key: str):
   return pd.read_csv(filename).rename(columns={key: 'text'})
 
 @op("Create prompt")
-def create_prompt(input, *, template: ops.LongStr):
-  assert template, 'Please specify the template. Refer to columns using their names in uppercase.'
-  p = template
-  for k, v in input.items():
-    p = p.replace(k.upper(), str(v))
-  return p
+def create_prompt(input, *, save_as='prompt', template: ops.LongStr):
+  assert template, 'Please specify the template. Refer to columns using the Jinja2 syntax.'
+  t = jinja.from_string(template)
+  prompt = t.render(**input)
+  return {**input, save_as: prompt}
 
 @op("Ask LLM")
 def ask_llm(input, *, model: str, accepted_regex: str = None, max_tokens: int = 100):
@@ -74,7 +76,7 @@ def view(input, *, _ctx: Context):
   v = {
     'dataframes': { 'df': {
       'columns': columns,
-      'data': [input[c] for c in columns],
+      'data': [[input[c] for c in columns]],
     }}
   }
   return v
@@ -92,12 +94,30 @@ def loop(input, *, max_iterations: int = 3, _ctx: Context):
 @op('Branch', outputs=['true', 'false'])
 def branch(input, *, expression: str):
   res = eval(expression, input)
-  return Output(str(bool(res)).lower(), input)
+  return Output(output_handle=str(bool(res)).lower(), value=input)
 
 @ops.input_position(db="top")
 @op('RAG')
-def rag(input, db, *, closest_n: int=10):
-  return input
+def rag(input, db, *, input_field='text', db_field='text', num_matches: int=10, _ctx: Context):
+  last = _ctx.last_result
+  if last:
+    collection = last['_collection']
+  else:
+    collection_name = _ctx.node.id.replace(' ', '_')
+    for c in chroma_client.list_collections():
+      if c.name == collection_name:
+        chroma_client.delete_collection(name=collection_name)
+    collection = chroma_client.create_collection(name=collection_name)
+    collection.add(
+      documents=[r[db_field] for r in db],
+      ids=[str(i) for i in range(len(db))],
+    )
+  results = collection.query(
+    query_texts=[input[input_field]],
+    n_results=num_matches,
+  )
+  results = [db[int(r)] for r in results['ids'][0]]
+  return {**input, 'rag': results, '_collection': collection}
 
 @op('Run Python')
 def run_python(input, *, template: str):
@@ -107,16 +127,16 @@ def run_python(input, *, template: str):
     p = p.replace(k.upper(), str(v))
   return p
 
-
+EXECUTOR_OUTPUT_CACHE = {}
 
 @ops.register_executor(ENV)
 def execute(ws):
   catalog = ops.CATALOGS[ENV]
   nodes = {n.id: n for n in ws.nodes}
-  contexts = {n.id: Context(n) for n in ws.nodes}
+  contexts = {n.id: Context(node=n) for n in ws.nodes}
   edges = {n.id: [] for n in ws.nodes}
   for e in ws.edges:
-    edges[e.source].append(e.target)
+    edges[e.source].append(e)
   tasks = {}
   NO_INPUT = object() # Marker for initial tasks.
   for node in ws.nodes:
@@ -125,39 +145,54 @@ def execute(ws):
     # Start tasks for nodes that have no inputs.
    if not op.inputs:
      tasks[node.id] = [NO_INPUT]
+  batch_inputs = {}
   # Run the rest until we run out of tasks.
-  while tasks:
-    n, ts = tasks.popitem()
-    node = nodes[n]
-    data = node.data
-    op = catalog[data.title]
-    params = {**data.params}
-    if has_ctx(op):
-      params['_ctx'] = contexts[node.id]
-    results = []
-    for task in ts:
-      try:
-        if task is NO_INPUT:
-          result = op(**params)
-        else:
-          # TODO: Tasks with multiple inputs?
-          result = op(task, **params)
-      except Exception as e:
-        traceback.print_exc()
-        data.error = str(e)
-        break
-      contexts[node.id].last_result = result
-      # Returned lists and DataFrames are considered multiple tasks.
-      if isinstance(result, pd.DataFrame):
-        result = df_to_list(result)
-      elif not isinstance(result, list):
-        result = [result]
-      results.extend(result)
-    else: # Finished all tasks without errors.
-      if op.type == 'visualization' or op.type == 'table_view':
-        data.display = results
-      for target in edges[node.id]:
-        tasks.setdefault(target, []).extend(results)
+  for stage in get_stages(ws):
+    next_stage = {}
+    while tasks:
+      n, ts = tasks.popitem()
+      if n not in stage:
+        next_stage.setdefault(n, []).extend(ts)
+        continue
+      node = nodes[n]
+      data = node.data
+      op = catalog[data.title]
+      params = {**data.params}
+      if has_ctx(op):
+        params['_ctx'] = contexts[node.id]
+      results = []
+      for task in ts:
+        try:
+          inputs = [
+            batch_inputs[(n, i.name)] if i.position == 'top' else task
+            for i in op.inputs.values()]
+          key = json.dumps(fastapi.encoders.jsonable_encoder((inputs, params)))
+          if key not in EXECUTOR_OUTPUT_CACHE:
+            EXECUTOR_OUTPUT_CACHE[key] = op.func(*inputs, **params)
+          result = EXECUTOR_OUTPUT_CACHE[key]
+        except Exception as e:
+          traceback.print_exc()
+          data.error = str(e)
+          break
+        contexts[node.id].last_result = result
+        # Returned lists and DataFrames are considered multiple tasks.
+        if isinstance(result, pd.DataFrame):
+          result = df_to_list(result)
+        elif not isinstance(result, list):
+          result = [result]
+        results.extend(result)
+      else: # Finished all tasks without errors.
+        if op.type == 'visualization' or op.type == 'table_view':
+          data.display = results[0]
+        for edge in edges[node.id]:
+          t = nodes[edge.target]
+          op = catalog[t.data.title]
+          i = op.inputs[edge.targetHandle]
+          if i.position == 'top':
+            batch_inputs.setdefault((edge.target, edge.targetHandle), []).extend(results)
+          else:
+            tasks.setdefault(edge.target, []).extend(results)
+    tasks = next_stage
 
 def df_to_list(df):
   return [dict(zip(df.columns, row)) for row in df.values]
@@ -165,3 +200,31 @@ def df_to_list(df):
 def has_ctx(op):
   sig = inspect.signature(op.func)
   return '_ctx' in sig.parameters
+
+def get_stages(ws):
+  '''Inputs on top are batch inputs. We decompose the graph into a DAG of components along these edges.'''
+  catalog = ops.CATALOGS[ENV]
+  nodes = {n.id: n for n in ws.nodes}
+  batch_inputs = {}
+  inputs = {}
+  for edge in ws.edges:
+    inputs.setdefault(edge.target, []).append(edge.source)
+    node = nodes[edge.target]
+    op = catalog[node.data.title]
+    i = op.inputs[edge.targetHandle]
+    if i.position == 'top':
+      batch_inputs.setdefault(edge.target, []).append(edge.source)
+  stages = []
+  for bt, bss in batch_inputs.items():
+    upstream = set(bss)
+    new = set(bss)
+    while new:
+      n = new.pop()
+      for i in inputs.get(n, []):
+        if i not in upstream:
+          upstream.add(i)
+          new.add(i)
+    stages.append(upstream)
+  stages.sort(key=lambda s: len(s))
+  stages.append(set(nodes))
+  return stages
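
The two most visible changes in this file are the Jinja2-based "Create prompt" op and the renamed LLM_CACHE used by chat(). A minimal sketch of those two pieces in isolation, outside the LynxKite executor; the row dict and model name below are made-up examples, not part of the commit:

import jinja2
import json

jinja = jinja2.Environment()
row = {'problem': 'What is 12 * 12?'}  # hypothetical row, shaped like what df_to_list() produces

# "Create prompt" now renders a Jinja2 template over the row's fields
# instead of substituting uppercase column names into the template string:
prompt = jinja.from_string('Solve this: {{ problem }}').render(**row)

# chat() keys its cache on the JSON-serialized call arguments, so repeated
# identical requests hit LLM_CACHE instead of the OpenAI-compatible endpoint:
key = json.dumps({'args': (), 'kwargs': {'model': 'mistral', 'messages': [{'role': 'user', 'content': prompt}]}})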
server/ops.py CHANGED
@@ -1,11 +1,8 @@
 '''API for implementing LynxKite operations.'''
 from __future__ import annotations
-import dataclasses
 import enum
 import functools
 import inspect
-import networkx as nx
-import pandas as pd
 import pydantic
 import typing
 from typing_extensions import Annotated
server/test_llm_ops.py CHANGED
@@ -2,27 +2,49 @@ import unittest
 from . import llm_ops
 from . import workspace
 
+def make_node(id, op, type='basic', **params):
+  return workspace.WorkspaceNode(
+    id=id,
+    type=type,
+    position=workspace.Position(x=0, y=0),
+    data=workspace.WorkspaceNodeData(title=op, params=params),
+  )
+def make_input(id):
+  return make_node(
+    id, 'Input',
+    filename='/Users/danieldarabos/Downloads/aimo-train.csv',
+    key='problem')
+def make_edge(source, target, targetHandle='input'):
+  return workspace.WorkspaceEdge(
+    id=f'{source}-{target}', source=source, target=target, sourceHandle='', targetHandle=targetHandle)
+
 class LLMOpsTest(unittest.TestCase):
   def testExecute(self):
     ws = workspace.Workspace(env='LLM logic', nodes=[
-      workspace.WorkspaceNode(
-        id='0',
-        type='basic',
-        position=workspace.Position(x=0, y=0),
-        data=workspace.WorkspaceNodeData(title='Input', params={
-          'filename': '/Users/danieldarabos/Downloads/aimo-train.csv',
-          'key': 'problem',
-        })),
-      workspace.WorkspaceNode(
-        id='1',
-        type='table_view',
-        position=workspace.Position(x=0, y=0),
-        data=workspace.WorkspaceNodeData(title='View', params={})),
+      make_node(
+        '0', 'Input',
+        filename='/Users/danieldarabos/Downloads/aimo-train.csv',
+        key='problem'),
+      make_node(
+        '1', 'View', type='table_view'),
     ], edges=[
-      workspace.WorkspaceEdge(id='0-1', source='0', target='1', sourceHandle='', targetHandle=''),
+      make_edge('0', '1')
     ])
     llm_ops.execute(ws)
     self.assertEqual('', ws.nodes[1].data.display)
 
+  def testStages(self):
+    ws = workspace.Workspace(env='LLM logic', nodes=[
+      make_input('in1'), make_input('in2'), make_input('in3'),
+      make_node('rag1', 'RAG'), make_node('rag2', 'RAG'),
+      make_node('p1', 'Create prompt'), make_node('p2', 'Create prompt'),
+    ], edges=[
+      make_edge('in1', 'rag1', 'db'), make_edge('in2', 'rag1'),
+      make_edge('rag1', 'p1'), make_edge('p1', 'rag2', 'db'),
+      make_edge('in3', 'p2'), make_edge('p3', 'rag2'),
+    ])
+    stages = llm_ops.get_stages(ws)
+    self.assertEqual('', stages)
+
 if __name__ == '__main__':
   unittest.main()
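
testStages exercises the new get_stages() decomposition. As a rough standalone illustration (not part of the commit) of the staging idea, here the 'db' handle stands in for the 'top' batch inputs, and the edge list mirrors the shape of the testStages workspace:

# Everything upstream of a batch ('top') input must finish in an earlier stage.
edges = [('in1', 'rag1', 'db'), ('in2', 'rag1', 'input'),
         ('rag1', 'p1', 'input'), ('p1', 'rag2', 'db'),
         ('in3', 'p2', 'input'), ('p2', 'rag2', 'input')]
inputs, batch_inputs = {}, {}
for src, dst, handle in edges:
  inputs.setdefault(dst, []).append(src)
  if handle == 'db':
    batch_inputs.setdefault(dst, []).append(src)
stages = []
for sources in batch_inputs.values():
  upstream, new = set(sources), set(sources)
  while new:  # collect the full upstream closure of the batch input
    n = new.pop()
    for i in inputs.get(n, []):
      if i not in upstream:
        upstream.add(i)
        new.add(i)
  stages.append(upstream)
stages.sort(key=len)
stages.append({'in1', 'in2', 'in3', 'rag1', 'rag2', 'p1', 'p2'})  # final stage: all nodes
print(stages)  # [{'in1'}, {'in1', 'in2', 'rag1', 'p1'}, {...all nodes...}]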
server/workspace.py CHANGED
@@ -4,7 +4,6 @@ import dataclasses
 import os
 import pydantic
 import tempfile
-import traceback
 from . import ops
 
 class BaseConfig(pydantic.BaseModel):