"""For specifying an LLM agent logic flow."""

import enum
import json

import chromadb
import jinja2
import numpy as np
import openai
import pandas as pd

from . import ops
from .executors import one_by_one

# OpenAI-compatible clients: one for chat completions, one for embeddings.
chat_client = openai.OpenAI(base_url="http://localhost:8080/v1")
embedding_client = openai.OpenAI(base_url="http://localhost:7997/")

jinja = jinja2.Environment()
chroma_client = chromadb.Client()
# In-memory cache of LLM results, keyed by the JSON-serialized request.
LLM_CACHE = {}
ENV = "LLM logic"
one_by_one.register(ENV)
op = ops.op_registration(ENV)


def chat(*args, **kwargs):
    """Return the chat completion texts for a request, cached in LLM_CACHE.

    All arguments are forwarded unchanged to
    ``chat_client.chat.completions.create``. They must be JSON-serializable,
    because the cache key is built with ``json.dumps``.

    Returns:
        A list with the message content of every completion choice.
    """
    key = json.dumps({"method": "chat", "args": args, "kwargs": kwargs})
    if key not in LLM_CACHE:
        completion = chat_client.chat.completions.create(*args, **kwargs)
        LLM_CACHE[key] = [c.message.content for c in completion.choices]
    return LLM_CACHE[key]


def embedding(*args, **kwargs):
    """Return the embedding for a request, cached in LLM_CACHE.

    All arguments are forwarded unchanged to
    ``embedding_client.embeddings.create``. They must be JSON-serializable,
    because the cache key is built with ``json.dumps``.

    Returns:
        The ``embedding`` attribute of the single returned data item.

    Raises:
        ValueError: If the response does not contain exactly one data item
            (raised by the unpacking below).
    """
    key = json.dumps({"method": "embedding", "args": args, "kwargs": kwargs})
    if key not in LLM_CACHE:
        res = embedding_client.embeddings.create(*args, **kwargs)
        # Exactly one embedding is expected per request.
        [data] = res.data
        LLM_CACHE[key] = data.embedding
    return LLM_CACHE[key]


@op("Input CSV")
def input_csv(*, filename: ops.PathStr, key: str):
    """Read a CSV file and rename the column named by `key` to "text"."""
    return pd.read_csv(filename).rename(columns={key: "text"})


@op("Input document")
def input_document(*, filename: ops.PathStr):
    """Read a whole file and return it as ``{"text": <contents>}``."""
    with open(filename) as f:
        return {"text": f.read()}


@op("Input chat")
def input_chat(*, chat: str):
    """Wrap the given chat string as ``{"text": chat}``."""
    return {"text": chat}


@op("Split document")
def split_document(input, *, delimiter: str = "\\n\\n"):
    """Split ``input["text"]`` on `delimiter` into a one-column DataFrame.

    Args:
        input: A mapping with a "text" entry.
        delimiter: The separator. Backslash escapes such as ``\\n`` are
            interpreted, so the default splits on a blank line.

    Returns:
        A DataFrame with a single "text" column, one row per chunk.
    """
    # Interpret backslash escapes. Encoding as Latin-1 with backslashreplace
    # keeps non-ASCII characters intact; a plain UTF-8 round-trip through
    # "unicode_escape" would turn them into mojibake.
    delimiter = delimiter.encode("latin-1", "backslashreplace").decode(
        "unicode_escape"
    )
    chunks = input["text"].split(delimiter)
    return pd.DataFrame(chunks, columns=["text"])


@ops.input_position(input="top")
@op("Build document graph")
def build_document_graph(input):
    """Return edges linking each position in `input` to the next one."""
    return [{"source": i, "target": i + 1} for i in range(len(input) - 1)]


@ops.input_position(nodes="top", edges="top")
@op("Predict links")
def predict_links(nodes, edges):
    """A placeholder for a real algorithm.
    For now just adds 2-hop neighbors."""
    edge_map = {}  # Source -> [Targets]
    for edge in edges:
        edge_map.setdefault(edge["source"], []).append(edge["target"])
    new_edges = [
        {"source": edge["source"], "target": t}
        for edge in edges
        for t in edge_map.get(edge["target"], [])
    ]
    return edges + new_edges


@ops.input_position(nodes="top", edges="top")
@op("Add neighbors")
def add_neighbors(nodes, edges, item):
    """Extend ``item["rag"]`` with the texts of the matches' out-neighbors.

    Args:
        nodes: Data convertible to a DataFrame with a "text" column.
        edges: Data convertible to a DataFrame with "source" and "target"
            columns. The values are used as index labels into `nodes`.
        item: A mapping whose "rag" entry is a list of node texts.

    Returns:
        A copy of `item` with the neighbor texts appended to "rag".

    Raises:
        IndexError: If a match is not the text of any node.
    """
    nodes = pd.DataFrame(nodes)
    # Naming the columns keeps the lookups below working for an empty edge list.
    edges = pd.DataFrame(edges, columns=["source", "target"])
    matches = item["rag"]
    additional_matches = []
    for m in matches:
        node = nodes[nodes["text"] == m].index[0]
        neighbors = edges[edges["source"] == node]["target"].to_list()
        additional_matches.extend(nodes.loc[neighbors, "text"])
    return {**item, "rag": matches + additional_matches}


@op("Create prompt")
def create_prompt(input, *, save_as="prompt", template: ops.LongStr):
    """Render a Jinja2 template with the entries of `input` as variables.

    Args:
        input: A mapping. Its entries are passed to the template as keyword
            arguments.
        save_as: The key under which the rendered prompt is stored.
        template: The Jinja2 template source. Must not be empty.

    Returns:
        A copy of `input` with the rendered prompt added under `save_as`.
    """
    assert (
        template
    ), "Please specify the template. Refer to columns using the Jinja2 syntax."
    t = jinja.from_string(template)
    prompt = t.render(**input)
    return {**input, save_as: prompt}


@op("Ask LLM")
def ask_llm(input, *, model: str, accepted_regex: str = None, max_tokens: int = 100):
    """Send ``input["prompt"]`` to the chat model as a single user message.

    Args:
        input: A mapping with a "prompt" entry.
        model: The model name passed to the chat endpoint.
        accepted_regex: If set, sent as ``guided_regex`` in the request's
            ``extra_body``.
        max_tokens: Passed to the chat endpoint as ``max_tokens``.

    Returns:
        A list with one copy of `input` per returned completion, each with
        the completion text under "response".
    """
    assert model, "Please specify the model."
    assert "prompt" in input, "Please create the prompt first."
    options = {}
    if accepted_regex:
        options["extra_body"] = {
            "guided_regex": accepted_regex,
        }
    results = chat(
        model=model,
        max_tokens=max_tokens,
        messages=[
            {"role": "user", "content": input["prompt"]},
        ],
        **options,
    )
    return [{**input, "response": r} for r in results]


@op("View", view="table_view")
def view(input, *, _ctx: one_by_one.Context):
    """Accumulate the incoming items as rows of a table.

    The first item fixes the columns: all keys that do not start with an
    underscore. Later items are appended as rows to the previous result
    (``_ctx.last_result``), which is modified in place.
    """
    v = _ctx.last_result
    if v:
        columns = v["dataframes"]["df"]["columns"]
        v["dataframes"]["df"]["data"].append([input[c] for c in columns])
    else:
        columns = [str(c) for c in input.keys() if not str(c).startswith("_")]
        v = {
            "dataframes": {
                "df": {
                    "columns": columns,
                    "data": [[input[c] for c in columns]],
                }
            }
        }
    return v


@ops.input_position(input="right")
@ops.output_position(output="left")
@op("Loop")
def loop(input, *, max_iterations: int = 3, _ctx: one_by_one.Context):
    """Data can flow back here max_iterations-1 times."""
    # The per-node iteration counter travels with the item itself.
    key = f"iterations-{_ctx.node.id}"
    input[key] = input.get(key, 0) + 1
    if input[key] < max_iterations:
        return input
    # Implicitly returns None once the limit is reached.


@op("Branch", outputs=["true", "false"])
def branch(input, *, expression: str):
    """Route `input` to the "true" or "false" output.

    Args:
        input: A dict. Its entries are the names visible to `expression`.
        expression: A Python expression. Its truthiness selects the output.

    Returns:
        A ``one_by_one.Output`` carrying `input` on the chosen handle.
    """
    # NOTE(review): eval() executes arbitrary Python. Only use this with
    # expressions from a trusted source.
    # Evaluate against a copy: eval() inserts "__builtins__" into the globals
    # dict it is given, which would otherwise leak into the forwarded item.
    res = eval(expression, dict(input))
    return one_by_one.Output(output_handle=str(bool(res)).lower(), value=input)


class RagEngine(enum.Enum):
    """The retrieval backends supported by the "RAG" operation."""

    Chroma = "Chroma"
    Custom = "Custom"


@ops.input_position(db="top")
@op("RAG")
def rag(
    input,
    db,
    *,
    engine: RagEngine = RagEngine.Chroma,
    input_field="text",
    db_field="text",
    num_matches: int = 10,
    _ctx: one_by_one.Context,
):
    """Find the entries of `db` that best match ``input[input_field]``.

    Args:
        input: A mapping holding the query text under `input_field`.
        db: An indexable collection of records with a `db_field` entry.
        engine: The retrieval backend.
        input_field: The key of the query text in `input`.
        db_field: The key of the searchable text in each record of `db`.
        num_matches: The maximum number of matches to return.

    Returns:
        A copy of `input` with the matches under "rag". With the Chroma
        engine the matches are the records of `db` and the collection is
        also stored under "_collection". With the Custom engine the matches
        are the `db_field` values of the records.
    """
    if engine == RagEngine.Chroma:
        last = _ctx.last_result
        if last:
            # Reuse the collection built for an earlier item.
            collection = last["_collection"]
        else:
            collection_name = _ctx.node.id.replace(" ", "_")
            # Drop any existing collection of the same name before rebuilding.
            for c in chroma_client.list_collections():
                if c.name == collection_name:
                    chroma_client.delete_collection(name=collection_name)
            collection = chroma_client.create_collection(name=collection_name)
            collection.add(
                documents=[r[db_field] for r in db],
                ids=[str(i) for i in range(len(db))],
            )
        query_result = collection.query(
            query_texts=[input[input_field]],
            n_results=num_matches,
        )
        # The IDs are the positions in db (see collection.add above).
        results = [db[int(r)] for r in query_result["ids"][0]]
        return {**input, "rag": results, "_collection": collection}
    if engine == RagEngine.Custom:
        model = "google/gemma-2-2b-it"
        query_text = input[input_field]
        embeddings = [embedding(input=[r[db_field]], model=model) for r in db]
        q = embedding(input=[query_text], model=model)

        def cosine_similarity(a, b):
            return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))

        scores = [(i, cosine_similarity(q, e)) for i, e in enumerate(embeddings)]
        # Highest similarity first.
        scores.sort(key=lambda x: -x[1])
        matches = [db[i][db_field] for i, _ in scores[:num_matches]]
        return {**input, "rag": matches}


@op("Run Python")
def run_python(input, *, template: str):
    """TODO: Implement."""
    return input