Spaces:
Running
Running
File size: 3,640 Bytes
dc3ebef 6988728 dc3ebef e8a8341 dc3ebef 6988728 a07e9cb e8a8341 a07e9cb dc3ebef 6988728 dc3ebef 6988728 dc3ebef e7fa7ee db436f7 dc3ebef e7fa7ee 6988728 dc3ebef e7fa7ee dc3ebef a07e9cb e7fa7ee a07e9cb e7fa7ee dc3ebef a07e9cb dc3ebef e7fa7ee e8a8341 a07e9cb 6988728 a07e9cb dc3ebef 4524b65 e7fa7ee e8a8341 a07e9cb e8a8341 a07e9cb e8a8341 6988728 a07e9cb 6988728 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 |
'''For specifying an LLM agent logic flow.'''
import json
import os
import re

import chromadb
import jinja2
import openai
import pandas as pd

from . import ops
from .executors import one_by_one
# Chat-completions client for a local OpenAI-compatible server.
# NOTE(review): port 11434 is Ollama's default, but ask_llm also sends the
# vLLM-style `guided_regex` option — confirm which backend is intended.
# openai.OpenAI() raises at construction when no API key is given and
# OPENAI_API_KEY is unset; a local server does not check the key, so fall back
# to a placeholder instead of failing at import time. A key set in the
# environment is still used as before.
client = openai.OpenAI(
    base_url="http://localhost:11434/v1",
    api_key=os.environ.get('OPENAI_API_KEY', 'ollama'),
)
# Shared Jinja2 environment used by "Create prompt" to compile templates.
jinja = jinja2.Environment()
# Chroma client used by the "RAG" op to build per-node collections.
chroma_client = chromadb.Client()
# Memoized chat() results, keyed by the JSON-encoded request; never evicted here.
LLM_CACHE = {}
# The environment name must be registered with the executor before the
# decorator factory for this environment is created.
ENV = 'LLM logic'
one_by_one.register(ENV)
op = ops.op_registration(ENV)
def chat(*args, **kwargs):
    '''Call the chat-completions API and memoize the reply texts.

    All arguments are forwarded unchanged to `client.chat.completions.create`.
    They must be JSON-serializable because they form the cache key
    (json.dumps raises TypeError otherwise).

    Returns:
        A list with the message content of each returned choice. The list is
        a fresh copy, so callers may modify it without corrupting LLM_CACHE.
    '''
    # sort_keys makes the key independent of keyword-argument order, so
    # equivalent requests share one cache entry.
    key = json.dumps({'args': args, 'kwargs': kwargs}, sort_keys=True)
    if key not in LLM_CACHE:
        completion = client.chat.completions.create(*args, **kwargs)
        LLM_CACHE[key] = [c.message.content for c in completion.choices]
    return list(LLM_CACHE[key])
@op("Input")
def input(*, filename: ops.PathStr, key: str):
    """Load a CSV file, exposing the column named `key` as 'text'.

    Args:
        filename: Path of the CSV file to read with pandas.
        key: Name of the column to rename to 'text'.

    Returns:
        The loaded pandas DataFrame with that column renamed.
    """
    frame = pd.read_csv(filename)
    renamed = frame.rename(columns={key: 'text'})
    return renamed
@op("Create prompt")
def create_prompt(input, *, save_as='prompt', template: ops.LongStr):
    """Render a Jinja2 template against the fields of `input`.

    Args:
        input: Row whose keys are available as template variables.
        save_as: Key under which the rendered text is stored.
        template: Jinja2 template text, e.g. "Summarize: {{ text }}".

    Returns:
        A copy of `input` with the rendered text added under `save_as`.
    """
    assert template, 'Please specify the template. Refer to columns using the Jinja2 syntax.'
    compiled = jinja.from_string(template)
    rendered = compiled.render(**input)
    output = {**input}
    output[save_as] = rendered
    return output
@op("Ask LLM")
def ask_llm(input, *, model: str, accepted_regex: str = None, max_tokens: int = 100):
    """Send input['prompt'] to `model` and emit one row per returned choice.

    Each output row is a copy of `input` with the reply text under 'response'.
    When `accepted_regex` is non-empty it is sent as `guided_regex` in the
    request's extra body (NOTE(review): a vLLM server extension; other
    OpenAI-compatible servers may ignore it — confirm for the backend in use).

    Args:
        input: Row that must already contain a 'prompt' field.
        model: Model name to request.
        accepted_regex: Optional regex constraining the generated text.
        max_tokens: Upper bound on generated tokens.
    """
    assert model, 'Please specify the model.'
    assert 'prompt' in input, 'Please create the prompt first.'
    request = {
        'model': model,
        'max_tokens': max_tokens,
        'messages': [{"role": "user", "content": input['prompt']}],
    }
    if accepted_regex:
        request['extra_body'] = {"guided_regex": accepted_regex}
    replies = chat(**request)
    return [{**input, 'response': reply} for reply in replies]
@op("View", view="table_view")
def view(input, *, _ctx: one_by_one.Context):
    """Accumulate incoming rows into a single table for the table view.

    The first row fixes the column list: every key not starting with '_' (so
    internal fields such as '_collection' are hidden). Later rows are appended
    using those same columns; the table from the previous call
    (`_ctx.last_result`) is extended in place and returned.
    """
    table = _ctx.last_result
    if not table:
        visible = [str(name) for name in input.keys() if not str(name).startswith('_')]
        return {
            'dataframes': {
                'df': {
                    'columns': visible,
                    'data': [[input[name] for name in visible]],
                },
            },
        }
    frame = table['dataframes']['df']
    columns = frame['columns']
    rows = frame['data']
    rows.append([input[name] for name in columns])
    return table
@ops.input_position(input="right")
@ops.output_position(output="left")
@op("Loop")
def loop(input, *, max_iterations: int = 3, _ctx: one_by_one.Context):
    '''Data can flow back here max_iterations-1 times.'''
    # The pass counter travels inside the data itself, keyed by this node's id
    # so that each Loop node keeps its own count.
    key = f'iterations-{_ctx.node.id}'
    count = input.get(key, 0) + 1
    if count >= max_iterations:
        # No output for this item (presumably the one_by_one executor drops
        # None results, ending the cycle — confirm).
        return None
    # Return an updated copy instead of mutating `input`: the same dict may
    # also have been handed to other nodes.
    return {**input, key: count}
@op('Branch', outputs=['true', 'false'])
def branch(input, *, expression: str):
    '''Route `input` to the 'true' or 'false' output based on `expression`.

    `expression` is evaluated as Python with the fields of `input` available
    as variable names. Its truthiness selects the output handle, and `input`
    is passed through unchanged.
    '''
    # eval() inserts '__builtins__' into the globals dict it is given, so
    # evaluate against a copy to keep that key out of the data flowing onward.
    # NOTE(review): eval executes arbitrary code; acceptable only while the
    # expression is authored by the workspace owner, never untrusted users.
    scope = dict(input)
    res = eval(expression, scope)
    return one_by_one.Output(output_handle=str(bool(res)).lower(), value=input)
def _recreate_collection(name):
    '''Return a new, empty Chroma collection called `name`, dropping any old one.'''
    for existing in chroma_client.list_collections():
        # Older chromadb releases return Collection objects here; chromadb 0.6+
        # returns plain collection names. Accept both.
        existing_name = getattr(existing, 'name', existing)
        if existing_name == name:
            chroma_client.delete_collection(name=name)
            break
    return chroma_client.create_collection(name=name)


@ops.input_position(db="top")
@op('RAG')
def rag(input, db, *, input_field='text', db_field='text', num_matches: int=10, _ctx: one_by_one.Context):
    '''Attach the `db` rows most similar to `input[input_field]` under 'rag'.

    On the first call for this node, the `db_field` of every `db` row is
    indexed into a fresh Chroma collection named after the node id. Later
    calls reuse that collection from the previous result.

    Args:
        input: Row holding the query text in `input_field`.
        db: Indexable sequence of rows (presumably dicts — confirm) whose
            `db_field` is indexed.
        num_matches: Number of nearest rows requested from Chroma.

    Returns:
        A copy of `input` with the matched `db` rows under 'rag' and the
        collection under '_collection' (the leading underscore marks it as
        internal).
    '''
    last = _ctx.last_result
    if last:
        # Reuse the collection indexed on this node's previous call.
        collection = last['_collection']
    else:
        collection = _recreate_collection(_ctx.node.id.replace(' ', '_'))
        collection.add(
            documents=[r[db_field] for r in db],
            ids=[str(i) for i in range(len(db))],
        )
    results = collection.query(
        query_texts=[input[input_field]],
        n_results=num_matches,
    )
    # Ids were assigned as row positions in `db`, so map them back to rows.
    matches = [db[int(i)] for i in results['ids'][0]]
    return {**input, 'rag': matches, '_collection': collection}
@op('Run Python')
def run_python(input, *, template: str):
    '''Fill `template` by replacing each field name, in uppercase, with its value.

    For example, with input {'text': 'hi'} the template "print('TEXT')"
    becomes "print('hi')". The filled text is returned as a string; it is not
    executed here.
    '''
    assert template, 'Please specify the template. Refer to columns using their names in uppercase.'
    replacements = {}
    for k, v in input.items():
        name = str(k).upper()
        if name:  # an empty name would match between every character
            # If two fields share an uppercase form, the first one wins.
            replacements.setdefault(name, str(v))
    if not replacements:
        return template
    # One left-to-right pass with the longest names tried first: 'CONTEXT' is
    # not clobbered by 'TEXT', and inserted values are never substituted again.
    pattern = re.compile('|'.join(
        re.escape(name) for name in sorted(replacements, key=len, reverse=True)))
    return pattern.sub(lambda m: replacements[m.group(0)], template)
|