darabos commited on
Commit
03a6805
·
1 Parent(s): 336b73c

Boxes for graph RAG.

Browse files
server/executors/one_by_one.py CHANGED
@@ -19,7 +19,7 @@ class Output(ops.BaseConfig):
19
 
20
 
21
  def df_to_list(df):
22
- return [dict(zip(df.columns, row)) for row in df.values]
23
 
24
  def has_ctx(op):
25
  sig = inspect.signature(op.func)
 
19
 
20
 
21
  def df_to_list(df):
22
+ return df.to_dict(orient='records')
23
 
24
  def has_ctx(op):
25
  sig = inspect.signature(op.func)
server/llm_ops.py CHANGED
@@ -1,13 +1,15 @@
1
  '''For specifying an LLM agent logic flow.'''
2
  from . import ops
3
  import chromadb
 
4
  import jinja2
5
  import json
6
  import openai
 
7
  import pandas as pd
8
  from .executors import one_by_one
9
 
10
- client = openai.OpenAI(base_url="http://localhost:11434/v1")
11
  jinja = jinja2.Environment()
12
  chroma_client = chromadb.Client()
13
  LLM_CACHE = {}
@@ -16,16 +18,71 @@ one_by_one.register(ENV)
16
  op = ops.op_registration(ENV)
17
 
18
  def chat(*args, **kwargs):
19
- key = json.dumps({'args': args, 'kwargs': kwargs})
20
  if key not in LLM_CACHE:
21
  completion = client.chat.completions.create(*args, **kwargs)
22
  LLM_CACHE[key] = [c.message.content for c in completion.choices]
23
  return LLM_CACHE[key]
24
 
25
- @op("Input")
26
- def input(*, filename: ops.PathStr, key: str):
 
 
 
 
 
 
 
 
27
  return pd.read_csv(filename).rename(columns={key: 'text'})
28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  @op("Create prompt")
30
  def create_prompt(input, *, save_as='prompt', template: ops.LongStr):
31
  assert template, 'Please specify the template. Refer to columns using the Jinja2 syntax.'
@@ -83,35 +140,50 @@ def branch(input, *, expression: str):
83
  res = eval(expression, input)
84
  return one_by_one.Output(output_handle=str(bool(res)).lower(), value=input)
85
 
 
 
 
 
86
  @ops.input_position(db="top")
87
  @op('RAG')
88
- def rag(input, db, *, input_field='text', db_field='text', num_matches: int=10, _ctx: one_by_one.Context):
89
- last = _ctx.last_result
90
- if last:
91
- collection = last['_collection']
92
- else:
93
- collection_name = _ctx.node.id.replace(' ', '_')
94
- for c in chroma_client.list_collections():
95
- if c.name == collection_name:
96
- chroma_client.delete_collection(name=collection_name)
97
- collection = chroma_client.create_collection(name=collection_name)
98
- collection.add(
99
- documents=[r[db_field] for r in db],
100
- ids=[str(i) for i in range(len(db))],
 
 
 
 
 
 
 
 
 
101
  )
102
- results = collection.query(
103
- query_texts=[input[input_field]],
104
- n_results=num_matches,
105
- )
106
- results = [db[int(r)] for r in results['ids'][0]]
107
- return {**input, 'rag': results, '_collection': collection}
 
 
 
 
 
 
 
108
 
109
  @op('Run Python')
110
  def run_python(input, *, template: str):
111
- assert template, 'Please specify the template. Refer to columns using their names in uppercase.'
112
- p = template
113
- for k, v in input.items():
114
- p = p.replace(k.upper(), str(v))
115
- return p
116
-
117
-
 
1
  '''For specifying an LLM agent logic flow.'''
2
  from . import ops
3
  import chromadb
4
+ import enum
5
  import jinja2
6
  import json
7
  import openai
8
+ import numpy as np
9
  import pandas as pd
10
  from .executors import one_by_one
11
 
12
+ client = openai.OpenAI(base_url="http://localhost:7997/")
13
  jinja = jinja2.Environment()
14
  chroma_client = chromadb.Client()
15
  LLM_CACHE = {}
 
18
  op = ops.op_registration(ENV)
19
 
20
  def chat(*args, **kwargs):
21
+ key = json.dumps({'method': 'chat', 'args': args, 'kwargs': kwargs})
22
  if key not in LLM_CACHE:
23
  completion = client.chat.completions.create(*args, **kwargs)
24
  LLM_CACHE[key] = [c.message.content for c in completion.choices]
25
  return LLM_CACHE[key]
26
 
27
+ def embedding(*args, **kwargs):
28
+ key = json.dumps({'method': 'embedding', 'args': args, 'kwargs': kwargs})
29
+ if key not in LLM_CACHE:
30
+ res = client.embeddings.create(*args, **kwargs)
31
+ [data] = res.data
32
+ LLM_CACHE[key] = data.embedding
33
+ return LLM_CACHE[key]
34
+
35
+ @op("Input CSV")
36
+ def input_csv(*, filename: ops.PathStr, key: str):
37
  return pd.read_csv(filename).rename(columns={key: 'text'})
38
 
39
+ @op("Input document")
40
+ def input_document(*, filename: ops.PathStr):
41
+ with open(filename) as f:
42
+ return {'text': f.read()}
43
+
44
+ @op("Input chat")
45
+ def input_chat(*, chat: str):
46
+ return {'text': chat}
47
+
48
+ @op("Split document")
49
+ def split_document(input, *, delimiter: str = '\\n\\n'):
50
+ delimiter = delimiter.encode().decode('unicode_escape')
51
+ chunks = input['text'].split(delimiter)
52
+ return pd.DataFrame(chunks, columns=['text'])
53
+
54
+ @ops.input_position(input="top")
55
+ @op("Build document graph")
56
+ def build_document_graph(input):
57
+ chunks = input['text']
58
+ return pd.DataFrame([{'source': i, 'target': i+1} for i in range(len(chunks)-1)]),
59
+
60
+ @ops.input_position(nodes="top", edges="top")
61
+ @op("Predict links")
62
+ def predict_links(nodes, edges):
63
+ '''A placeholder for a real algorithm. For now just adds 2-hop neighbors.'''
64
+ edges = edges.to_dict(orient='records')
65
+ edge_map = {} # Source -> [Targets]
66
+ for edge in edges:
67
+ edge_map.setdefault(edge['source'], [])
68
+ edge_map[edge['source']].append(edge['target'])
69
+ new_edges = []
70
+ for source, target in edges.items():
71
+ for t in edge_map.get(target, []):
72
+ new_edges.append({'source': source, 'target': t})
73
+ return pd.DataFrame(edges.append(new_edges))
74
+
75
+ @ops.input_position(nodes="top", edges="top")
76
+ @op("Add neighbors")
77
+ def add_neighbors(nodes, edges, item):
78
+ matches = item['rag']
79
+ additional_matches = []
80
+ for m in matches:
81
+ node = nodes[nodes['text'] == m].index[0]
82
+ neighbors = edges[edges['source'] == node]['target']
83
+ additional_matches.extend(nodes.loc[neighbors, 'text'])
84
+ return {**item, 'rag': matches + additional_matches}
85
+
86
  @op("Create prompt")
87
  def create_prompt(input, *, save_as='prompt', template: ops.LongStr):
88
  assert template, 'Please specify the template. Refer to columns using the Jinja2 syntax.'
 
140
  res = eval(expression, input)
141
  return one_by_one.Output(output_handle=str(bool(res)).lower(), value=input)
142
 
143
+ class RagEngine(enum.Enum):
144
+ Chroma = 'Chroma'
145
+ Custom = 'Custom'
146
+
147
  @ops.input_position(db="top")
148
  @op('RAG')
149
+ def rag(
150
+ input, db, *,
151
+ engine: RagEngine = RagEngine.Chroma,
152
+ input_field='text', db_field='text', num_matches: int = 10,
153
+ _ctx: one_by_one.Context):
154
+ if engine == RagEngine.Chroma:
155
+ last = _ctx.last_result
156
+ if last:
157
+ collection = last['_collection']
158
+ else:
159
+ collection_name = _ctx.node.id.replace(' ', '_')
160
+ for c in chroma_client.list_collections():
161
+ if c.name == collection_name:
162
+ chroma_client.delete_collection(name=collection_name)
163
+ collection = chroma_client.create_collection(name=collection_name)
164
+ collection.add(
165
+ documents=[r[db_field] for r in db],
166
+ ids=[str(i) for i in range(len(db))],
167
+ )
168
+ results = collection.query(
169
+ query_texts=[input[input_field]],
170
+ n_results=num_matches,
171
  )
172
+ results = [db[int(r)] for r in results['ids'][0]]
173
+ return {**input, 'rag': results, '_collection': collection}
174
+ if engine == RagEngine.Custom:
175
+ model = 'google/gemma-2-2b-it'
176
+ chat = input[input_field]
177
+ embeddings = [embedding(input=[r[db_field]], model=model) for r in db]
178
+ q = embedding(input=[chat], model=model)
179
+ def cosine_similarity(a, b):
180
+ return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
181
+ scores = [(i, cosine_similarity(q, e)) for i, e in enumerate(embeddings)]
182
+ scores.sort(key=lambda x: -x[1])
183
+ matches = [db[i][db_field] for i, _ in scores[:num_matches]]
184
+ return {**input, 'rag': matches}
185
 
186
  @op('Run Python')
187
  def run_python(input, *, template: str):
188
+ '''TODO: Implement.'''
189
+ return input
 
 
 
 
 
server/ops.py CHANGED
@@ -98,11 +98,12 @@ class Op(BaseConfig):
98
  params[p] = int(params[p])
99
  elif self.params[p].type == float:
100
  params[p] = float(params[p])
 
 
101
  res = self.func(*inputs, **params)
102
  return res
103
 
104
 
105
-
106
  def op(env: str, name: str, *, view='basic', sub_nodes=None, outputs=None):
107
  '''Decorator for defining an operation.'''
108
  def decorator(func):
 
98
  params[p] = int(params[p])
99
  elif self.params[p].type == float:
100
  params[p] = float(params[p])
101
+ elif isinstance(self.params[p].type, enum.EnumMeta):
102
+ params[p] = self.params[p].type[params[p]]
103
  res = self.func(*inputs, **params)
104
  return res
105
 
106
 
 
107
  def op(env: str, name: str, *, view='basic', sub_nodes=None, outputs=None):
108
  '''Decorator for defining an operation.'''
109
  def decorator(func):