'''For specifying an LLM agent logic flow.'''
from . import ops
import chromadb
import enum
import jinja2
import json
import openai
import numpy as np
import pandas as pd
from .executors import one_by_one

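# These clients assume locally hosted OpenAI-compatible servers: chat on port 8080, embeddings on port 7997.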
chat_client = openai.OpenAI(base_url="http://localhost:8080/v1")
embedding_client = openai.OpenAI(base_url="http://localhost:7997/")
jinja = jinja2.Environment()
chroma_client = chromadb.Client()
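# In-memory cache, so identical chat and embedding requests are only sent once.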
LLM_CACHE = {}
ENV = 'LLM logic'
one_by_one.register(ENV)
op = ops.op_registration(ENV)

def chat(*args, **kwargs):
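  '''Cached chat completion call; returns the message text of every choice.'''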
  key = json.dumps({'method': 'chat', 'args': args, 'kwargs': kwargs})
  if key not in LLM_CACHE:
    completion = chat_client.chat.completions.create(*args, **kwargs)
    LLM_CACHE[key] = [c.message.content for c in completion.choices]
  return LLM_CACHE[key]

def embedding(*args, **kwargs):
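  '''Cached embedding call for a single input; returns its embedding vector.'''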
  key = json.dumps({'method': 'embedding', 'args': args, 'kwargs': kwargs})
  if key not in LLM_CACHE:
    res = embedding_client.embeddings.create(*args, **kwargs)
    [data] = res.data
    LLM_CACHE[key] = data.embedding
  return LLM_CACHE[key]

@op("Input CSV")
def input_csv(*, filename: ops.PathStr, key: str):
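  '''Reads a CSV file and renames the selected column to "text".'''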
  return pd.read_csv(filename).rename(columns={key: 'text'})

@op("Input document")
def input_document(*, filename: ops.PathStr):
  with open(filename) as f:
    return {'text': f.read()}

@op("Input chat")
def input_chat(*, chat: str):
  return {'text': chat}

@op("Split document")
def split_document(input, *, delimiter: str = '\\n\\n'):
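  # The delimiter may arrive with literal escape sequences (e.g. "\n\n" as four characters), so decode it first.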
  delimiter = delimiter.encode().decode('unicode_escape')
  chunks = input['text'].split(delimiter)
  return pd.DataFrame(chunks, columns=['text'])

@ops.input_position(input="top")
@op("Build document graph")
def build_document_graph(input):
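  '''Chains the chunks into a linear graph: chunk i links to chunk i+1.'''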
  return [{'source': i, 'target': i+1} for i in range(len(input)-1)]

@ops.input_position(nodes="top", edges="top")
@op("Predict links")
def predict_links(nodes, edges):
  '''A placeholder for a real algorithm. For now just adds 2-hop neighbors.'''
  edge_map = {} # Source -> [Targets]
  for edge in edges:
    edge_map.setdefault(edge['source'], [])
    edge_map[edge['source']].append(edge['target'])
  new_edges = []
  for edge in edges:
    for t in edge_map.get(edge['target'], []):
      new_edges.append({'source': edge['source'], 'target': t})
  return edges + new_edges

@ops.input_position(nodes="top", edges="top")
@op("Add neighbors")
def add_neighbors(nodes, edges, item):
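  '''Extends the item's RAG matches with the texts of their direct graph neighbors.'''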
  nodes = pd.DataFrame(nodes)
  edges = pd.DataFrame(edges)
  matches = item['rag']
  additional_matches = []
  for m in matches:
    node = nodes[nodes['text'] == m].index[0]
    neighbors = edges[edges['source'] == node]['target'].to_list()
    additional_matches.extend(nodes.loc[neighbors, 'text'])
  return {**item, 'rag': matches + additional_matches}

@op("Create prompt")
def create_prompt(input, *, save_as='prompt', template: ops.LongStr):
  assert template, 'Please specify the template. Refer to columns using the Jinja2 syntax.'
  t = jinja.from_string(template)
  prompt = t.render(**input)
  return {**input, save_as: prompt}

@op("Ask LLM")
def ask_llm(input, *, model: str, accepted_regex: str | None = None, max_tokens: int = 100):
  assert model, 'Please specify the model.'
  assert 'prompt' in input, 'Please create the prompt first.'
  options = {}
  if accepted_regex:
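    # extra_body passes "guided_regex" through to servers that support guided decoding (e.g. vLLM).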
    options['extra_body'] = {
      "guided_regex": accepted_regex,
    }
  results = chat(
    model=model,
    max_tokens=max_tokens,
    messages=[
      {"role": "user", "content": input['prompt']},
    ],
    **options,
  )
  return [{**input, 'response': r} for r in results]

@op("View", view="table_view")
def view(input, *, _ctx: one_by_one.Context):
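  '''Collects the incoming items into one table; keys starting with "_" are hidden.'''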
  v = _ctx.last_result
  if v:
    columns = v['dataframes']['df']['columns']
    v['dataframes']['df']['data'].append([input[c] for c in columns])
  else:
    columns = [str(c) for c in input.keys() if not str(c).startswith('_')]
    v = {
      'dataframes': { 'df': {
        'columns': columns,
        'data': [[input[c] for c in columns]],
      }}
    }
  return v

@ops.input_position(input="right")
@ops.output_position(output="left")
@op("Loop")
def loop(input, *, max_iterations: int = 3, _ctx: one_by_one.Context):
  '''Data can flow back here max_iterations-1 times.'''
  key = f'iterations-{_ctx.node.id}'
  input[key] = input.get(key, 0) + 1
  if input[key] < max_iterations:
    return input

@op('Branch', outputs=['true', 'false'])
def branch(input, *, expression: str):
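  '''Evaluates the expression on the input and routes it to the "true" or "false" output.'''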
  res = eval(expression, dict(input))  # Evaluate on a copy, so eval() cannot inject __builtins__ into the input.
  return one_by_one.Output(output_handle=str(bool(res)).lower(), value=input)

class RagEngine(enum.Enum):
  Chroma = 'Chroma'
  Custom = 'Custom'

@ops.input_position(db="top")
@op('RAG')
def rag(
  input, db, *,
  engine: RagEngine = RagEngine.Chroma,
  input_field='text', db_field='text', num_matches: int = 10,
  _ctx: one_by_one.Context):
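  '''Finds the db records most relevant to the input, via ChromaDB or a simple embedding search.'''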
  if engine == RagEngine.Chroma:
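    # Reuse the collection built for an earlier item; it is threaded through the results via '_collection'.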
    last = _ctx.last_result
    if last:
      collection = last['_collection']
    else:
      collection_name = _ctx.node.id.replace(' ', '_')
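      # Drop any stale collection left over from a previous run before rebuilding it.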
      for c in chroma_client.list_collections():
        if c.name == collection_name:
          chroma_client.delete_collection(name=collection_name)
      collection = chroma_client.create_collection(name=collection_name)
      collection.add(
        documents=[r[db_field] for r in db],
        ids=[str(i) for i in range(len(db))],
      )
    results = collection.query(
      query_texts=[input[input_field]],
      n_results=num_matches,
    )
    results = [db[int(r)] for r in results['ids'][0]]
    return {**input, 'rag': results, '_collection': collection}
  if engine == RagEngine.Custom:
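    # Brute-force alternative: embed every record and the query, then rank by cosine similarity.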
    model = 'google/gemma-2-2b-it'
    query = input[input_field]  # Renamed from "chat" so it does not shadow the chat() helper above.
    embeddings = [embedding(input=[r[db_field]], model=model) for r in db]
    q = embedding(input=[query], model=model)
    def cosine_similarity(a, b):
      return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
    scores = [(i, cosine_similarity(q, e)) for i, e in enumerate(embeddings)]
    scores.sort(key=lambda x: -x[1])
    matches = [db[i][db_field] for i, _ in scores[:num_matches]]
    return {**input, 'rag': matches}

@op('Run Python')
def run_python(input, *, template: str):
  '''TODO: Implement.'''
  return input