Spaces:
				
			
			
	
			
			
		Build error
		
	
	
	
			
			
	
	
	
	
		
		
		Build error
		
	Boxes for graph RAG.
Browse files- server/executors/one_by_one.py +1 -1
 - server/llm_ops.py +102 -30
 - server/ops.py +2 -1
 
    	
        server/executors/one_by_one.py
    CHANGED
    
    | 
         @@ -19,7 +19,7 @@ class Output(ops.BaseConfig): 
     | 
|
| 19 | 
         | 
| 20 | 
         | 
| 21 | 
         
             
            def df_to_list(df):
         
     | 
| 22 | 
         
            -
              return  
     | 
| 23 | 
         | 
| 24 | 
         
             
            def has_ctx(op):
         
     | 
| 25 | 
         
             
              sig = inspect.signature(op.func)
         
     | 
| 
         | 
|
| 19 | 
         | 
| 20 | 
         | 
| 21 | 
         
             
            def df_to_list(df):
         
     | 
| 22 | 
         
            +
              return df.to_dict(orient='records')
         
     | 
| 23 | 
         | 
| 24 | 
         
             
            def has_ctx(op):
         
     | 
| 25 | 
         
             
              sig = inspect.signature(op.func)
         
     | 
    	
        server/llm_ops.py
    CHANGED
    
    | 
         @@ -1,13 +1,15 @@ 
     | 
|
| 1 | 
         
             
            '''For specifying an LLM agent logic flow.'''
         
     | 
| 2 | 
         
             
            from . import ops
         
     | 
| 3 | 
         
             
            import chromadb
         
     | 
| 
         | 
|
| 4 | 
         
             
            import jinja2
         
     | 
| 5 | 
         
             
            import json
         
     | 
| 6 | 
         
             
            import openai
         
     | 
| 
         | 
|
| 7 | 
         
             
            import pandas as pd
         
     | 
| 8 | 
         
             
            from .executors import one_by_one
         
     | 
| 9 | 
         | 
| 10 | 
         
            -
            client = openai.OpenAI(base_url="http://localhost: 
     | 
| 11 | 
         
             
            jinja = jinja2.Environment()
         
     | 
| 12 | 
         
             
            chroma_client = chromadb.Client()
         
     | 
| 13 | 
         
             
            LLM_CACHE = {}
         
     | 
| 
         @@ -16,16 +18,71 @@ one_by_one.register(ENV) 
     | 
|
| 16 | 
         
             
            op = ops.op_registration(ENV)
         
     | 
| 17 | 
         | 
| 18 | 
         
             
            def chat(*args, **kwargs):
         
     | 
| 19 | 
         
            -
              key = json.dumps({'args': args, 'kwargs': kwargs})
         
     | 
| 20 | 
         
             
              if key not in LLM_CACHE:
         
     | 
| 21 | 
         
             
                completion = client.chat.completions.create(*args, **kwargs)
         
     | 
| 22 | 
         
             
                LLM_CACHE[key] = [c.message.content for c in completion.choices]
         
     | 
| 23 | 
         
             
              return LLM_CACHE[key]
         
     | 
| 24 | 
         | 
| 25 | 
         
            -
             
     | 
| 26 | 
         
            -
             
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 27 | 
         
             
              return pd.read_csv(filename).rename(columns={key: 'text'})
         
     | 
| 28 | 
         | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 29 | 
         
             
            @op("Create prompt")
         
     | 
| 30 | 
         
             
            def create_prompt(input, *, save_as='prompt', template: ops.LongStr):
         
     | 
| 31 | 
         
             
              assert template, 'Please specify the template. Refer to columns using the Jinja2 syntax.'
         
     | 
| 
         @@ -83,35 +140,50 @@ def branch(input, *, expression: str): 
     | 
|
| 83 | 
         
             
              res = eval(expression, input)
         
     | 
| 84 | 
         
             
              return one_by_one.Output(output_handle=str(bool(res)).lower(), value=input)
         
     | 
| 85 | 
         | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 86 | 
         
             
            @ops.input_position(db="top")
         
     | 
| 87 | 
         
             
            @op('RAG')
         
     | 
| 88 | 
         
            -
            def rag( 
     | 
| 89 | 
         
            -
               
     | 
| 90 | 
         
            -
               
     | 
| 91 | 
         
            -
             
     | 
| 92 | 
         
            -
               
     | 
| 93 | 
         
            -
             
     | 
| 94 | 
         
            -
                 
     | 
| 95 | 
         
            -
             
     | 
| 96 | 
         
            -
             
     | 
| 97 | 
         
            -
                 
     | 
| 98 | 
         
            -
             
     | 
| 99 | 
         
            -
                   
     | 
| 100 | 
         
            -
             
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 101 | 
         
             
                )
         
     | 
| 102 | 
         
            -
             
     | 
| 103 | 
         
            -
                 
     | 
| 104 | 
         
            -
             
     | 
| 105 | 
         
            -
             
     | 
| 106 | 
         
            -
             
     | 
| 107 | 
         
            -
             
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 108 | 
         | 
| 109 | 
         
             
            @op('Run Python')
         
     | 
| 110 | 
         
             
            def run_python(input, *, template: str):
         
     | 
| 111 | 
         
            -
               
     | 
| 112 | 
         
            -
               
     | 
| 113 | 
         
            -
              for k, v in input.items():
         
     | 
| 114 | 
         
            -
                p = p.replace(k.upper(), str(v))
         
     | 
| 115 | 
         
            -
              return p
         
     | 
| 116 | 
         
            -
             
     | 
| 117 | 
         
            -
             
     | 
| 
         | 
|
| 1 | 
         
             
            '''For specifying an LLM agent logic flow.'''
         
     | 
| 2 | 
         
             
            from . import ops
         
     | 
| 3 | 
         
             
            import chromadb
         
     | 
| 4 | 
         
            +
            import enum
         
     | 
| 5 | 
         
             
            import jinja2
         
     | 
| 6 | 
         
             
            import json
         
     | 
| 7 | 
         
             
            import openai
         
     | 
| 8 | 
         
            +
            import numpy as np
         
     | 
| 9 | 
         
             
            import pandas as pd
         
     | 
| 10 | 
         
             
            from .executors import one_by_one
         
     | 
| 11 | 
         | 
| 12 | 
         
            +
            client = openai.OpenAI(base_url="http://localhost:7997/")
         
     | 
| 13 | 
         
             
            jinja = jinja2.Environment()
         
     | 
| 14 | 
         
             
            chroma_client = chromadb.Client()
         
     | 
| 15 | 
         
             
            LLM_CACHE = {}
         
     | 
| 
         | 
|
| 18 | 
         
             
            op = ops.op_registration(ENV)
         
     | 
| 19 | 
         | 
| 20 | 
         
             
            def chat(*args, **kwargs):
         
     | 
| 21 | 
         
            +
              key = json.dumps({'method': 'chat', 'args': args, 'kwargs': kwargs})
         
     | 
| 22 | 
         
             
              if key not in LLM_CACHE:
         
     | 
| 23 | 
         
             
                completion = client.chat.completions.create(*args, **kwargs)
         
     | 
| 24 | 
         
             
                LLM_CACHE[key] = [c.message.content for c in completion.choices]
         
     | 
| 25 | 
         
             
              return LLM_CACHE[key]
         
     | 
| 26 | 
         | 
| 27 | 
         
            +
            def embedding(*args, **kwargs):
         
     | 
| 28 | 
         
            +
              key = json.dumps({'method': 'embedding', 'args': args, 'kwargs': kwargs})
         
     | 
| 29 | 
         
            +
              if key not in LLM_CACHE:
         
     | 
| 30 | 
         
            +
                res = client.embeddings.create(*args, **kwargs)
         
     | 
| 31 | 
         
            +
                [data] = res.data
         
     | 
| 32 | 
         
            +
                LLM_CACHE[key] = data.embedding
         
     | 
| 33 | 
         
            +
              return LLM_CACHE[key]
         
     | 
| 34 | 
         
            +
             
     | 
| 35 | 
         
            +
            @op("Input CSV")
         
     | 
| 36 | 
         
            +
            def input_csv(*, filename: ops.PathStr, key: str):
         
     | 
| 37 | 
         
             
              return pd.read_csv(filename).rename(columns={key: 'text'})
         
     | 
| 38 | 
         | 
| 39 | 
         
            +
            @op("Input document")
         
     | 
| 40 | 
         
            +
            def input_document(*, filename: ops.PathStr):
         
     | 
| 41 | 
         
            +
              with open(filename) as f:
         
     | 
| 42 | 
         
            +
                return {'text': f.read()}
         
     | 
| 43 | 
         
            +
             
     | 
| 44 | 
         
            +
            @op("Input chat")
         
     | 
| 45 | 
         
            +
            def input_chat(*, chat: str):
         
     | 
| 46 | 
         
            +
              return {'text': chat}
         
     | 
| 47 | 
         
            +
             
     | 
| 48 | 
         
            +
            @op("Split document")
         
     | 
| 49 | 
         
            +
            def split_document(input, *, delimiter: str = '\\n\\n'):
         
     | 
| 50 | 
         
            +
              delimiter = delimiter.encode().decode('unicode_escape')
         
     | 
| 51 | 
         
            +
              chunks = input['text'].split(delimiter)
         
     | 
| 52 | 
         
            +
              return pd.DataFrame(chunks, columns=['text'])
         
     | 
| 53 | 
         
            +
             
     | 
| 54 | 
         
            +
            @ops.input_position(input="top")
         
     | 
| 55 | 
         
            +
            @op("Build document graph")
         
     | 
| 56 | 
         
            +
            def build_document_graph(input):
         
     | 
| 57 | 
         
            +
              chunks = input['text']
         
     | 
| 58 | 
         
            +
              return pd.DataFrame([{'source': i, 'target': i+1} for i in range(len(chunks)-1)]),
         
     | 
| 59 | 
         
            +
             
     | 
| 60 | 
         
            +
            @ops.input_position(nodes="top", edges="top")
         
     | 
| 61 | 
         
            +
            @op("Predict links")
         
     | 
| 62 | 
         
            +
            def predict_links(nodes, edges):
         
     | 
| 63 | 
         
            +
              '''A placeholder for a real algorithm. For now just adds 2-hop neighbors.'''
         
     | 
| 64 | 
         
            +
              edges = edges.to_dict(orient='records')
         
     | 
| 65 | 
         
            +
              edge_map = {} # Source -> [Targets]
         
     | 
| 66 | 
         
            +
              for edge in edges:
         
     | 
| 67 | 
         
            +
                edge_map.setdefault(edge['source'], [])
         
     | 
| 68 | 
         
            +
                edge_map[edge['source']].append(edge['target'])
         
     | 
| 69 | 
         
            +
              new_edges = []
         
     | 
| 70 | 
         
            +
              for source, target in edges.items():
         
     | 
| 71 | 
         
            +
                for t in edge_map.get(target, []):
         
     | 
| 72 | 
         
            +
                  new_edges.append({'source': source, 'target': t})
         
     | 
| 73 | 
         
            +
              return pd.DataFrame(edges.append(new_edges))
         
     | 
| 74 | 
         
            +
             
     | 
| 75 | 
         
            +
            @ops.input_position(nodes="top", edges="top")
         
     | 
| 76 | 
         
            +
            @op("Add neighbors")
         
     | 
| 77 | 
         
            +
            def add_neighbors(nodes, edges, item):
         
     | 
| 78 | 
         
            +
              matches = item['rag']
         
     | 
| 79 | 
         
            +
              additional_matches = []
         
     | 
| 80 | 
         
            +
              for m in matches:
         
     | 
| 81 | 
         
            +
                node = nodes[nodes['text'] == m].index[0]
         
     | 
| 82 | 
         
            +
                neighbors = edges[edges['source'] == node]['target']
         
     | 
| 83 | 
         
            +
                additional_matches.extend(nodes.loc[neighbors, 'text'])
         
     | 
| 84 | 
         
            +
              return {**item, 'rag': matches + additional_matches}
         
     | 
| 85 | 
         
            +
             
     | 
| 86 | 
         
             
            @op("Create prompt")
         
     | 
| 87 | 
         
             
            def create_prompt(input, *, save_as='prompt', template: ops.LongStr):
         
     | 
| 88 | 
         
             
              assert template, 'Please specify the template. Refer to columns using the Jinja2 syntax.'
         
     | 
| 
         | 
|
| 140 | 
         
             
              res = eval(expression, input)
         
     | 
| 141 | 
         
             
              return one_by_one.Output(output_handle=str(bool(res)).lower(), value=input)
         
     | 
| 142 | 
         | 
| 143 | 
         
            +
            class RagEngine(enum.Enum):
         
     | 
| 144 | 
         
            +
              Chroma = 'Chroma'
         
     | 
| 145 | 
         
            +
              Custom = 'Custom'
         
     | 
| 146 | 
         
            +
             
     | 
| 147 | 
         
             
            @ops.input_position(db="top")
         
     | 
| 148 | 
         
             
            @op('RAG')
         
     | 
| 149 | 
         
            +
            def rag(
         
     | 
| 150 | 
         
            +
              input, db, *,
         
     | 
| 151 | 
         
            +
              engine: RagEngine = RagEngine.Chroma,
         
     | 
| 152 | 
         
            +
              input_field='text', db_field='text', num_matches: int = 10,
         
     | 
| 153 | 
         
            +
              _ctx: one_by_one.Context):
         
     | 
| 154 | 
         
            +
              if engine == RagEngine.Chroma:
         
     | 
| 155 | 
         
            +
                last = _ctx.last_result
         
     | 
| 156 | 
         
            +
                if last:
         
     | 
| 157 | 
         
            +
                  collection = last['_collection']
         
     | 
| 158 | 
         
            +
                else:
         
     | 
| 159 | 
         
            +
                  collection_name = _ctx.node.id.replace(' ', '_')
         
     | 
| 160 | 
         
            +
                  for c in chroma_client.list_collections():
         
     | 
| 161 | 
         
            +
                    if c.name == collection_name:
         
     | 
| 162 | 
         
            +
                      chroma_client.delete_collection(name=collection_name)
         
     | 
| 163 | 
         
            +
                  collection = chroma_client.create_collection(name=collection_name)
         
     | 
| 164 | 
         
            +
                  collection.add(
         
     | 
| 165 | 
         
            +
                    documents=[r[db_field] for r in db],
         
     | 
| 166 | 
         
            +
                    ids=[str(i) for i in range(len(db))],
         
     | 
| 167 | 
         
            +
                  )
         
     | 
| 168 | 
         
            +
                results = collection.query(
         
     | 
| 169 | 
         
            +
                  query_texts=[input[input_field]],
         
     | 
| 170 | 
         
            +
                  n_results=num_matches,
         
     | 
| 171 | 
         
             
                )
         
     | 
| 172 | 
         
            +
                results = [db[int(r)] for r in results['ids'][0]]
         
     | 
| 173 | 
         
            +
                return {**input, 'rag': results, '_collection': collection}
         
     | 
| 174 | 
         
            +
              if engine == RagEngine.Custom:
         
     | 
| 175 | 
         
            +
                model = 'google/gemma-2-2b-it'
         
     | 
| 176 | 
         
            +
                chat = input[input_field]
         
     | 
| 177 | 
         
            +
                embeddings = [embedding(input=[r[db_field]], model=model) for r in db]
         
     | 
| 178 | 
         
            +
                q = embedding(input=[chat], model=model)
         
     | 
| 179 | 
         
            +
                def cosine_similarity(a, b):
         
     | 
| 180 | 
         
            +
                  return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
         
     | 
| 181 | 
         
            +
                scores = [(i, cosine_similarity(q, e)) for i, e in enumerate(embeddings)]
         
     | 
| 182 | 
         
            +
                scores.sort(key=lambda x: -x[1])
         
     | 
| 183 | 
         
            +
                matches = [db[i][db_field] for i, _ in scores[:num_matches]]
         
     | 
| 184 | 
         
            +
                return {**input, 'rag': matches}
         
     | 
| 185 | 
         | 
| 186 | 
         
             
            @op('Run Python')
         
     | 
| 187 | 
         
             
            def run_python(input, *, template: str):
         
     | 
| 188 | 
         
            +
              '''TODO: Implement.'''
         
     | 
| 189 | 
         
            +
              return input
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
    	
        server/ops.py
    CHANGED
    
    | 
         @@ -98,11 +98,12 @@ class Op(BaseConfig): 
     | 
|
| 98 | 
         
             
                      params[p] = int(params[p])
         
     | 
| 99 | 
         
             
                    elif self.params[p].type == float:
         
     | 
| 100 | 
         
             
                      params[p] = float(params[p])
         
     | 
| 
         | 
|
| 
         | 
|
| 101 | 
         
             
                res = self.func(*inputs, **params)
         
     | 
| 102 | 
         
             
                return res
         
     | 
| 103 | 
         | 
| 104 | 
         | 
| 105 | 
         
            -
             
     | 
| 106 | 
         
             
            def op(env: str, name: str, *, view='basic', sub_nodes=None, outputs=None):
         
     | 
| 107 | 
         
             
              '''Decorator for defining an operation.'''
         
     | 
| 108 | 
         
             
              def decorator(func):
         
     | 
| 
         | 
|
| 98 | 
         
             
                      params[p] = int(params[p])
         
     | 
| 99 | 
         
             
                    elif self.params[p].type == float:
         
     | 
| 100 | 
         
             
                      params[p] = float(params[p])
         
     | 
| 101 | 
         
            +
                    elif isinstance(self.params[p].type, enum.EnumMeta):
         
     | 
| 102 | 
         
            +
                      params[p] = self.params[p].type[params[p]]
         
     | 
| 103 | 
         
             
                res = self.func(*inputs, **params)
         
     | 
| 104 | 
         
             
                return res
         
     | 
| 105 | 
         | 
| 106 | 
         | 
| 
         | 
|
| 107 | 
         
             
            def op(env: str, name: str, *, view='basic', sub_nodes=None, outputs=None):
         
     | 
| 108 | 
         
             
              '''Decorator for defining an operation.'''
         
     | 
| 109 | 
         
             
              def decorator(func):
         
     |