File size: 3,640 Bytes
dc3ebef
 
6988728
 
dc3ebef
 
 
e8a8341
dc3ebef
 
6988728
 
 
a07e9cb
e8a8341
a07e9cb
dc3ebef
 
 
6988728
dc3ebef
6988728
 
dc3ebef
e7fa7ee
db436f7
dc3ebef
 
e7fa7ee
6988728
 
 
 
 
dc3ebef
e7fa7ee
 
dc3ebef
a07e9cb
 
e7fa7ee
a07e9cb
e7fa7ee
dc3ebef
a07e9cb
 
 
 
 
 
 
 
 
dc3ebef
e7fa7ee
e8a8341
a07e9cb
 
 
 
 
 
 
 
 
6988728
a07e9cb
 
dc3ebef
4524b65
 
 
e7fa7ee
e8a8341
a07e9cb
 
 
 
 
 
 
 
 
e8a8341
a07e9cb
 
 
e8a8341
6988728
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a07e9cb
 
 
 
 
 
 
 
 
6988728
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
'''For specifying an LLM agent logic flow.'''
from . import ops
import chromadb
import jinja2
import json
import openai
import pandas as pd
from .executors import one_by_one

# OpenAI-compatible client pointed at a server on localhost:11434.
# NOTE(review): presumably a local Ollama/vLLM-style endpoint — confirm.
client = openai.OpenAI(base_url="http://localhost:11434/v1")
# Shared Jinja2 environment used to compile prompt templates.
jinja = jinja2.Environment()
# Chroma client used by the RAG op to create and query collections.
chroma_client = chromadb.Client()
# Cache of chat() results, keyed by the JSON-serialized call arguments.
LLM_CACHE = {}
# Name of the environment that the executor and the ops below are registered under.
ENV = 'LLM logic'
one_by_one.register(ENV)
# Decorator used below to register each operation in this environment.
op = ops.op_registration(ENV)

def chat(*args, **kwargs):
  '''Return the message contents of a chat completion, cached by arguments.

  All arguments are forwarded unchanged to client.chat.completions.create().
  They must be JSON-serializable, because their JSON form is the cache key.
  Returns a list with the content of each choice in the completion.
  '''
  # sort_keys makes the key independent of keyword order, so equivalent calls
  # share a single cache entry instead of triggering another request.
  key = json.dumps({'args': args, 'kwargs': kwargs}, sort_keys=True)
  if key not in LLM_CACHE:
    completion = client.chat.completions.create(*args, **kwargs)
    LLM_CACHE[key] = [c.message.content for c in completion.choices]
  return LLM_CACHE[key]

@op("Input")
def input(*, filename: ops.PathStr, key: str):
  '''Read a CSV file and rename the column called `key` to 'text'.'''
  table = pd.read_csv(filename)
  renamed = table.rename(columns={key: 'text'})
  return renamed

@op("Create prompt")
def create_prompt(input, *, save_as='prompt', template: ops.LongStr):
  '''Render the Jinja2 template with the input's fields as template variables.

  Returns a copy of the input with the rendered text stored under `save_as`.
  '''
  assert template, 'Please specify the template. Refer to columns using the Jinja2 syntax.'
  rendered = jinja.from_string(template).render(**input)
  output = {**input}
  output[save_as] = rendered
  return output

@op("Ask LLM")
def ask_llm(input, *, model: str, accepted_regex: str = None, max_tokens: int = 100):
  '''Send the input's 'prompt' to the model as a single user message.

  If accepted_regex is set, it is sent as "guided_regex" in the request's extra_body.
  Returns a list with one copy of the input per completion, each carrying
  the completion text under 'response'.
  '''
  assert model, 'Please specify the model.'
  assert 'prompt' in input, 'Please create the prompt first.'
  # Keyword order is kept as model, max_tokens, messages, extra_body.
  request = {
    'model': model,
    'max_tokens': max_tokens,
    'messages': [{"role": "user", "content": input['prompt']}],
  }
  if accepted_regex:
    request['extra_body'] = {"guided_regex": accepted_regex}
  return [{**input, 'response': text} for text in chat(**request)]

@op("View", view="table_view")
def view(input, *, _ctx: one_by_one.Context):
  '''Collect successive inputs as rows of a single table.

  On the first call the columns are the input's keys that do not start with
  an underscore. On later calls a row is appended to the previous result
  (taken from _ctx.last_result), using the columns chosen on the first call.
  '''
  table = _ctx.last_result
  if not table:
    columns = [str(c) for c in input.keys() if not str(c).startswith('_')]
    first_row = [input[c] for c in columns]
    return {'dataframes': {'df': {'columns': columns, 'data': [first_row]}}}
  # The previous result is extended in place and returned again.
  df = table['dataframes']['df']
  columns = df['columns']
  df['data'].append([input[c] for c in columns])
  return table

@ops.input_position(input="right")
@ops.output_position(output="left")
@op("Loop")
def loop(input, *, max_iterations: int = 3, _ctx: one_by_one.Context):
  '''Data can flow back here max_iterations-1 times.

  A per-node counter is stored in the input under 'iterations-<node id>'.
  The input is returned while the counter is below max_iterations;
  after that None is returned.
  '''
  counter_key = f'iterations-{_ctx.node.id}'
  count = input.get(counter_key, 0) + 1
  input[counter_key] = count
  if count >= max_iterations:
    return None
  return input

@op('Branch', outputs=['true', 'false'])
def branch(input, *, expression: str):
  '''Send the input to the 'true' or 'false' output based on a Python expression.

  The expression is evaluated with the input's fields available as variables.
  The truthiness of its result selects the output handle. The input itself is
  forwarded unchanged.
  '''
  # eval() inserts a '__builtins__' entry into the globals dict it receives.
  # Evaluate against a copy so that entry does not leak into the forwarded data.
  # NOTE(review): eval() executes arbitrary code; the expression must come from a trusted source.
  scope = dict(input)
  res = eval(expression, scope)
  return one_by_one.Output(output_handle=str(bool(res)).lower(), value=input)

@ops.input_position(db="top")
@op('RAG')
def rag(input, db, *, input_field='text', db_field='text', num_matches: int=10, _ctx: one_by_one.Context):
  '''Attach the db records that best match the input's `input_field`.

  On the first call a Chroma collection named after the node id is (re)created
  and filled with the `db_field` of every db record, using the record's index
  as its id. Later calls reuse the collection stored in the previous result.
  Returns a copy of the input with the matching records under 'rag' and the
  collection under '_collection'.
  '''
  previous = _ctx.last_result
  if previous:
    collection = previous['_collection']
  else:
    name = _ctx.node.id.replace(' ', '_')
    # Drop any collection left over under the same name before rebuilding it.
    for existing in chroma_client.list_collections():
      if existing.name == name:
        chroma_client.delete_collection(name=name)
    collection = chroma_client.create_collection(name=name)
    documents = [record[db_field] for record in db]
    collection.add(documents=documents, ids=list(map(str, range(len(db)))))
  response = collection.query(query_texts=[input[input_field]], n_results=num_matches)
  # Ids are record indices, so map them back to the original records.
  matches = [db[int(doc_id)] for doc_id in response['ids'][0]]
  return {**input, 'rag': matches, '_collection': collection}

@op('Run Python')
def run_python(input, *, template: str):
  '''Fill in the template by replacing each uppercased input key with its value.

  Replacements are applied one key at a time, in the input's iteration order.
  Returns the resulting text; this function does not execute it.
  '''
  assert template, 'Please specify the template. Refer to columns using their names in uppercase.'
  filled = template
  for name, value in input.items():
    placeholder = name.upper()
    filled = filled.replace(placeholder, str(value))
  return filled