darabos committed on
Commit
0832c91
·
1 Parent(s): d60fc1c

Format/lint changes.

Browse files
server/llm_ops.py CHANGED
@@ -1,4 +1,5 @@
1
- '''For specifying an LLM agent logic flow.'''
 
2
  from . import ops
3
  import chromadb
4
  import enum
@@ -14,177 +15,205 @@ embedding_client = openai.OpenAI(base_url="http://localhost:7997/")
14
  jinja = jinja2.Environment()
15
  chroma_client = chromadb.Client()
16
  LLM_CACHE = {}
17
- ENV = 'LLM logic'
18
  one_by_one.register(ENV)
19
  op = ops.op_registration(ENV)
20
 
 
21
  def chat(*args, **kwargs):
22
- key = json.dumps({'method': 'chat', 'args': args, 'kwargs': kwargs})
23
- if key not in LLM_CACHE:
24
- completion = chat_client.chat.completions.create(*args, **kwargs)
25
- LLM_CACHE[key] = [c.message.content for c in completion.choices]
26
- return LLM_CACHE[key]
 
27
 
28
  def embedding(*args, **kwargs):
29
- key = json.dumps({'method': 'embedding', 'args': args, 'kwargs': kwargs})
30
- if key not in LLM_CACHE:
31
- res = embedding_client.embeddings.create(*args, **kwargs)
32
- [data] = res.data
33
- LLM_CACHE[key] = data.embedding
34
- return LLM_CACHE[key]
 
35
 
36
  @op("Input CSV")
37
  def input_csv(*, filename: ops.PathStr, key: str):
38
- return pd.read_csv(filename).rename(columns={key: 'text'})
 
39
 
40
  @op("Input document")
41
  def input_document(*, filename: ops.PathStr):
42
- with open(filename) as f:
43
- return {'text': f.read()}
 
44
 
45
  @op("Input chat")
46
  def input_chat(*, chat: str):
47
- return {'text': chat}
 
48
 
49
  @op("Split document")
50
- def split_document(input, *, delimiter: str = '\\n\\n'):
51
- delimiter = delimiter.encode().decode('unicode_escape')
52
- chunks = input['text'].split(delimiter)
53
- return pd.DataFrame(chunks, columns=['text'])
 
54
 
55
  @ops.input_position(input="top")
56
  @op("Build document graph")
57
  def build_document_graph(input):
58
- return [{'source': i, 'target': i+1} for i in range(len(input)-1)]
 
59
 
60
  @ops.input_position(nodes="top", edges="top")
61
  @op("Predict links")
62
  def predict_links(nodes, edges):
63
- '''A placeholder for a real algorithm. For now just adds 2-hop neighbors.'''
64
- edge_map = {} # Source -> [Targets]
65
- for edge in edges:
66
- edge_map.setdefault(edge['source'], [])
67
- edge_map[edge['source']].append(edge['target'])
68
- new_edges = []
69
- for edge in edges:
70
- for t in edge_map.get(edge['target'], []):
71
- new_edges.append({'source': edge['source'], 'target': t})
72
- return edges + new_edges
 
73
 
74
  @ops.input_position(nodes="top", edges="top")
75
  @op("Add neighbors")
76
  def add_neighbors(nodes, edges, item):
77
- nodes = pd.DataFrame(nodes)
78
- edges = pd.DataFrame(edges)
79
- matches = item['rag']
80
- additional_matches = []
81
- for m in matches:
82
- node = nodes[nodes['text'] == m].index[0]
83
- neighbors = edges[edges['source'] == node]['target'].to_list()
84
- additional_matches.extend(nodes.loc[neighbors, 'text'])
85
- return {**item, 'rag': matches + additional_matches}
 
86
 
87
  @op("Create prompt")
88
- def create_prompt(input, *, save_as='prompt', template: ops.LongStr):
89
- assert template, 'Please specify the template. Refer to columns using the Jinja2 syntax.'
90
- t = jinja.from_string(template)
91
- prompt = t.render(**input)
92
- return {**input, save_as: prompt}
 
 
 
93
 
94
  @op("Ask LLM")
95
  def ask_llm(input, *, model: str, accepted_regex: str = None, max_tokens: int = 100):
96
- assert model, 'Please specify the model.'
97
- assert 'prompt' in input, 'Please create the prompt first.'
98
- options = {}
99
- if accepted_regex:
100
- options['extra_body'] = {
101
- "guided_regex": accepted_regex,
102
- }
103
- results = chat(
104
- model=model,
105
- max_tokens=max_tokens,
106
- messages=[
107
- {"role": "user", "content": input['prompt']},
108
- ],
109
- **options,
110
- )
111
- return [{**input, 'response': r} for r in results]
 
112
 
113
  @op("View", view="table_view")
114
  def view(input, *, _ctx: one_by_one.Context):
115
- v = _ctx.last_result
116
- if v:
117
- columns = v['dataframes']['df']['columns']
118
- v['dataframes']['df']['data'].append([input[c] for c in columns])
119
- else:
120
- columns = [str(c) for c in input.keys() if not str(c).startswith('_')]
121
- v = {
122
- 'dataframes': { 'df': {
123
- 'columns': columns,
124
- 'data': [[input[c] for c in columns]],
125
- }}
126
- }
127
- return v
 
 
 
128
 
129
  @ops.input_position(input="right")
130
  @ops.output_position(output="left")
131
  @op("Loop")
132
  def loop(input, *, max_iterations: int = 3, _ctx: one_by_one.Context):
133
- '''Data can flow back here max_iterations-1 times.'''
134
- key = f'iterations-{_ctx.node.id}'
135
- input[key] = input.get(key, 0) + 1
136
- if input[key] < max_iterations:
137
- return input
 
138
 
139
- @op('Branch', outputs=['true', 'false'])
140
  def branch(input, *, expression: str):
141
- res = eval(expression, input)
142
- return one_by_one.Output(output_handle=str(bool(res)).lower(), value=input)
 
143
 
144
  class RagEngine(enum.Enum):
145
- Chroma = 'Chroma'
146
- Custom = 'Custom'
 
147
 
148
  @ops.input_position(db="top")
149
- @op('RAG')
150
  def rag(
151
- input, db, *,
152
- engine: RagEngine = RagEngine.Chroma,
153
- input_field='text', db_field='text', num_matches: int = 10,
154
- _ctx: one_by_one.Context):
155
- if engine == RagEngine.Chroma:
156
- last = _ctx.last_result
157
- if last:
158
- collection = last['_collection']
159
- else:
160
- collection_name = _ctx.node.id.replace(' ', '_')
161
- for c in chroma_client.list_collections():
162
- if c.name == collection_name:
163
- chroma_client.delete_collection(name=collection_name)
164
- collection = chroma_client.create_collection(name=collection_name)
165
- collection.add(
166
- documents=[r[db_field] for r in db],
167
- ids=[str(i) for i in range(len(db))],
168
- )
169
- results = collection.query(
170
- query_texts=[input[input_field]],
171
- n_results=num_matches,
172
- )
173
- results = [db[int(r)] for r in results['ids'][0]]
174
- return {**input, 'rag': results, '_collection': collection}
175
- if engine == RagEngine.Custom:
176
- model = 'google/gemma-2-2b-it'
177
- chat = input[input_field]
178
- embeddings = [embedding(input=[r[db_field]], model=model) for r in db]
179
- q = embedding(input=[chat], model=model)
180
- def cosine_similarity(a, b):
181
- return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
182
- scores = [(i, cosine_similarity(q, e)) for i, e in enumerate(embeddings)]
183
- scores.sort(key=lambda x: -x[1])
184
- matches = [db[i][db_field] for i, _ in scores[:num_matches]]
185
- return {**input, 'rag': matches}
186
-
187
- @op('Run Python')
 
 
 
 
 
 
 
 
188
  def run_python(input, *, template: str):
189
- '''TODO: Implement.'''
190
- return input
 
1
+ """For specifying an LLM agent logic flow."""
2
+
3
  from . import ops
4
  import chromadb
5
  import enum
 
15
  jinja = jinja2.Environment()
16
  chroma_client = chromadb.Client()
17
  LLM_CACHE = {}
18
+ ENV = "LLM logic"
19
  one_by_one.register(ENV)
20
  op = ops.op_registration(ENV)
21
 
22
+
23
def chat(*args, **kwargs):
    """Call the chat-completions API, memoizing results in LLM_CACHE.

    Returns the list of message contents, one per returned choice.
    All arguments are forwarded to ``chat_client.chat.completions.create``.
    """
    # sort_keys makes the cache key independent of keyword-argument ordering,
    # so semantically identical calls share a single cache entry.
    key = json.dumps({"method": "chat", "args": args, "kwargs": kwargs}, sort_keys=True)
    if key not in LLM_CACHE:
        completion = chat_client.chat.completions.create(*args, **kwargs)
        LLM_CACHE[key] = [c.message.content for c in completion.choices]
    return LLM_CACHE[key]
29
+
30
 
31
def embedding(*args, **kwargs):
    """Call the embeddings API for a single input, memoizing in LLM_CACHE.

    Expects exactly one embedding in the response (asserted by unpacking)
    and returns it as a plain vector. Arguments are forwarded to
    ``embedding_client.embeddings.create``.
    """
    # sort_keys makes the cache key independent of keyword-argument ordering,
    # so semantically identical calls share a single cache entry.
    key = json.dumps({"method": "embedding", "args": args, "kwargs": kwargs}, sort_keys=True)
    if key not in LLM_CACHE:
        res = embedding_client.embeddings.create(*args, **kwargs)
        # Single-element unpack: fails loudly if the API returns != 1 result.
        [data] = res.data
        LLM_CACHE[key] = data.embedding
    return LLM_CACHE[key]
38
+
39
 
40
@op("Input CSV")
def input_csv(*, filename: ops.PathStr, key: str):
    """Load a CSV file, exposing the ``key`` column under the name "text"."""
    df = pd.read_csv(filename)
    return df.rename(columns={key: "text"})
43
+
44
 
45
@op("Input document")
def input_document(*, filename: ops.PathStr):
    """Read an entire file into a single record with a "text" field."""
    with open(filename) as f:
        contents = f.read()
    return {"text": contents}
49
+
50
 
51
@op("Input chat")
def input_chat(*, chat: str):
    """Wrap a chat message string as a record with a "text" field."""
    return {"text": chat}
54
+
55
 
56
@op("Split document")
def split_document(input, *, delimiter: str = "\\n\\n"):
    """Split the document's text into chunks, one DataFrame row per chunk.

    The delimiter is typed with escape sequences (e.g. "\\n\\n") and decoded
    here, so newlines and similar characters can be entered in the UI.
    """
    sep = delimiter.encode().decode("unicode_escape")
    pieces = input["text"].split(sep)
    return pd.DataFrame({"text": pieces})
61
+
62
 
63
@ops.input_position(input="top")
@op("Build document graph")
def build_document_graph(input):
    """Chain consecutive chunks: edge i-1 -> i for every adjacent pair."""
    return [{"source": i - 1, "target": i} for i in range(1, len(input))]
67
+
68
 
69
@ops.input_position(nodes="top", edges="top")
@op("Predict links")
def predict_links(nodes, edges):
    """A placeholder for a real algorithm. For now just adds 2-hop neighbors."""
    # Adjacency list: source node -> list of target nodes.
    targets_of = {}
    for e in edges:
        targets_of.setdefault(e["source"], []).append(e["target"])
    # For every a->b edge, also propose a->c for each b->c edge.
    two_hop = [
        {"source": e["source"], "target": t}
        for e in edges
        for t in targets_of.get(e["target"], [])
    ]
    return edges + two_hop
82
+
83
 
84
@ops.input_position(nodes="top", edges="top")
@op("Add neighbors")
def add_neighbors(nodes, edges, item):
    """Extend item["rag"] with the graph neighbors of each matched node.

    Each match is located in ``nodes`` by its "text" value; its neighbors
    are the targets of edges whose source is the matched node's index.
    """
    node_df = pd.DataFrame(nodes)
    edge_df = pd.DataFrame(edges)
    matches = item["rag"]
    extra = []
    for text in matches:
        # Index of the first node whose text equals the match
        # (raises IndexError if no node matches).
        node_id = node_df[node_df["text"] == text].index[0]
        neighbor_ids = edge_df[edge_df["source"] == node_id]["target"].to_list()
        extra.extend(node_df.loc[neighbor_ids, "text"])
    return {**item, "rag": matches + extra}
96
+
97
 
98
@op("Create prompt")
def create_prompt(input, *, save_as="prompt", template: ops.LongStr):
    """Render the Jinja2 ``template`` with the item's fields as variables,
    storing the result on the item under the ``save_as`` key."""
    assert (
        template
    ), "Please specify the template. Refer to columns using the Jinja2 syntax."
    rendered = jinja.from_string(template).render(**input)
    return {**input, save_as: rendered}
106
+
107
 
108
@op("Ask LLM")
def ask_llm(input, *, model: str, accepted_regex: str = None, max_tokens: int = 100):
    """Send input["prompt"] to the LLM; emits one item per completion choice.

    ``accepted_regex`` optionally constrains the output via guided decoding
    (passed through as the "guided_regex" extra body field).
    """
    assert model, "Please specify the model."
    assert "prompt" in input, "Please create the prompt first."
    extra = {"extra_body": {"guided_regex": accepted_regex}} if accepted_regex else {}
    responses = chat(
        model=model,
        max_tokens=max_tokens,
        messages=[{"role": "user", "content": input["prompt"]}],
        **extra,
    )
    return [{**input, "response": r} for r in responses]
126
+
127
 
128
@op("View", view="table_view")
def view(input, *, _ctx: one_by_one.Context):
    """Accumulate items into a table for display.

    Reuses (and mutates) the previous result, appending one row per item.
    Columns are fixed by the first item, skipping "_"-prefixed fields.
    """
    v = _ctx.last_result
    if not v:
        # First item: derive the visible columns and start a fresh table.
        columns = [str(c) for c in input.keys() if not str(c).startswith("_")]
        return {
            "dataframes": {
                "df": {
                    "columns": columns,
                    "data": [[input[c] for c in columns]],
                }
            }
        }
    df = v["dataframes"]["df"]
    df["data"].append([input[c] for c in df["columns"]])
    return v
145
+
146
 
147
@ops.input_position(input="right")
@ops.output_position(output="left")
@op("Loop")
def loop(input, *, max_iterations: int = 3, _ctx: one_by_one.Context):
    """Data can flow back here max_iterations-1 times."""
    counter_key = f"iterations-{_ctx.node.id}"
    count = input.get(counter_key, 0) + 1
    # Mutates the item so the iteration count travels with it.
    input[counter_key] = count
    if count < max_iterations:
        return input
    # Returning None drops the item, which terminates the loop.
156
+
157
 
158
@op("Branch", outputs=["true", "false"])
def branch(input, *, expression: str):
    """Route the item to the "true" or "false" output handle.

    ``expression`` is a Python expression evaluated with the item's fields
    as variables. SECURITY: eval() executes arbitrary code — only safe for
    trusted, user-authored workspaces.
    """
    # Evaluate against a copy: eval() inserts a "__builtins__" entry into
    # the globals dict it is given, which would silently mutate the item.
    res = eval(expression, dict(input))
    return one_by_one.Output(output_handle=str(bool(res)).lower(), value=input)
162
+
163
 
164
class RagEngine(enum.Enum):
    """Available retrieval backends for the RAG operation."""

    # Vector search backed by a ChromaDB collection.
    Chroma = "Chroma"
    # Hand-rolled cosine-similarity search over raw embeddings.
    Custom = "Custom"
167
+
168
 
169
@ops.input_position(db="top")
@op("RAG")
def rag(
    input,
    db,
    *,
    engine: RagEngine = RagEngine.Chroma,
    input_field="text",
    db_field="text",
    num_matches: int = 10,
    _ctx: one_by_one.Context,
):
    """Attach the ``db`` records most similar to input[input_field] as "rag".

    Chroma engine: builds a collection while processing the first item and
    reuses it for later items via ``_ctx.last_result``; returns full db
    records. Custom engine: embeds the query and every db record on each
    call and ranks by cosine similarity; returns the matching field values.
    """
    if engine == RagEngine.Chroma:
        previous = _ctx.last_result
        if previous:
            # Reuse the collection built for the first item.
            collection = previous["_collection"]
        else:
            name = _ctx.node.id.replace(" ", "_")
            # Drop any stale collection left over from an earlier run.
            if any(c.name == name for c in chroma_client.list_collections()):
                chroma_client.delete_collection(name=name)
            collection = chroma_client.create_collection(name=name)
            # Ids are db indices, so query hits map straight back to records.
            collection.add(
                documents=[record[db_field] for record in db],
                ids=[str(i) for i in range(len(db))],
            )
        hits = collection.query(
            query_texts=[input[input_field]],
            n_results=num_matches,
        )
        found = [db[int(doc_id)] for doc_id in hits["ids"][0]]
        return {**input, "rag": found, "_collection": collection}
    if engine == RagEngine.Custom:
        model = "google/gemma-2-2b-it"
        query_text = input[input_field]
        doc_vectors = [embedding(input=[record[db_field]], model=model) for record in db]
        query_vector = embedding(input=[query_text], model=model)

        def similarity(a, b):
            # Cosine similarity of two embedding vectors.
            return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))

        # Indices of db records, best match first (stable for ties).
        ranked = sorted(
            range(len(doc_vectors)),
            key=lambda i: similarity(query_vector, doc_vectors[i]),
            reverse=True,
        )
        matches = [db[i][db_field] for i in ranked[:num_matches]]
        return {**input, "rag": matches}
214
+
215
+
216
@op("Run Python")
def run_python(input, *, template: str):
    """TODO: Implement."""
    # Stub: passes the item through unchanged until implemented.
    return input
web/src/workspace/nodes/LynxKiteNode.tsx CHANGED
@@ -1,5 +1,3 @@
1
- import { useContext } from 'react';
2
- import { LynxKiteState } from '../LynxKiteState';
3
  import { useReactFlow, Handle, NodeResizeControl, Position } from '@xyflow/react';
4
  // @ts-ignore
5
  import ChevronDownRight from '~icons/tabler/chevron-down-right.jsx';
@@ -45,7 +43,6 @@ function getHandles(inputs: object, outputs: object) {
45
  export default function LynxKiteNode(props: LynxKiteNodeProps) {
46
  const reactFlow = useReactFlow();
47
  const data = props.data;
48
- const state = useContext(LynxKiteState);
49
  const expanded = !data.collapsed;
50
  const handles = getHandles(data.meta?.inputs || {}, data.meta?.outputs || {});
51
  function asPx(n: number | undefined) {
 
 
 
1
  import { useReactFlow, Handle, NodeResizeControl, Position } from '@xyflow/react';
2
  // @ts-ignore
3
  import ChevronDownRight from '~icons/tabler/chevron-down-right.jsx';
 
43
  export default function LynxKiteNode(props: LynxKiteNodeProps) {
44
  const reactFlow = useReactFlow();
45
  const data = props.data;
 
46
  const expanded = !data.collapsed;
47
  const handles = getHandles(data.meta?.inputs || {}, data.meta?.outputs || {});
48
  function asPx(n: number | undefined) {