darabos commited on
Commit
a509341
·
1 Parent(s): 2d3da64

Optional caching for one-by-one executor.

Browse files
Files changed (1) hide show
  1. server/executors/one_by_one.py +11 -4
server/executors/one_by_one.py CHANGED
@@ -25,9 +25,16 @@ def has_ctx(op):
25
  sig = inspect.signature(op.func)
26
  return '_ctx' in sig.parameters
27
 
28
- def register(env: str):
 
 
29
  '''Registers the one-by-one executor.'''
30
- ops.EXECUTORS[env] = execute
 
 
 
 
 
31
 
32
  def get_stages(ws, catalog):
33
  '''Inputs on top are batch inputs. We decompose the graph into a DAG of components along these edges.'''
@@ -93,13 +100,13 @@ def execute(ws, catalog, cache=None):
93
  inputs = [
94
  batch_inputs[(n, i.name)] if i.position == 'top' else task
95
  for i in op.inputs.values()]
96
- key = json.dumps(fastapi.encoders.jsonable_encoder((inputs, params)))
97
  if cache:
 
98
  if key not in cache:
99
  cache[key] = op.func(*inputs, **params)
100
  result = cache[key]
101
  else:
102
- result = op.func(*inputs, **params)
103
  except Exception as e:
104
  traceback.print_exc()
105
  data.error = str(e)
 
25
  sig = inspect.signature(op.func)
26
  return '_ctx' in sig.parameters
27
 
28
+ CACHES = {}
29
+
30
+ def register(env: str, cache: bool = True):
31
  '''Registers the one-by-one executor.'''
32
+ if cache:
33
+ CACHES[env] = {}
34
+ cache = CACHES[env]
35
+ else:
36
+ cache = None
37
+ ops.EXECUTORS[env] = lambda ws: execute(ws, ops.CATALOGS[env], cache=cache)
38
 
39
  def get_stages(ws, catalog):
40
  '''Inputs on top are batch inputs. We decompose the graph into a DAG of components along these edges.'''
 
100
  inputs = [
101
  batch_inputs[(n, i.name)] if i.position == 'top' else task
102
  for i in op.inputs.values()]
 
103
  if cache:
104
+ key = json.dumps(fastapi.encoders.jsonable_encoder((inputs, params)))
105
  if key not in cache:
106
  cache[key] = op.func(*inputs, **params)
107
  result = cache[key]
108
  else:
109
+ result = op(*inputs, **params)
110
  except Exception as e:
111
  traceback.print_exc()
112
  data.error = str(e)