|
|
|
|
|
import numpy as np |
|
import logging |
|
|
|
class MolNode:
    """A molecule node in a search tree.

    Tracks a cost-like ``value`` (re-derived as the minimum over children once
    the node is expanded), whether a solution exists beneath it (``succ``), and
    the cost of the best such solution (``succ_value``; ``inf`` means none).

    NOTE(review): ``parent`` and the entries of ``children`` are duck-typed.
    From the calls made in this class, a parent must expose ``depth``,
    ``children``, ``parent``, ``v_target()`` and
    ``backup(v_delta, from_mol=...)``; a child must expose ``v_self()``,
    ``succ`` and ``succ_value``. They appear to be reaction nodes (loop
    variables are named ``reaction``) -- confirm against the caller.
    """

    def __init__(self, mol, init_value, parent=None, is_known=False,
                 zero_known_value=True):
        """
        :param mol: molecule identifier (rendered with ``%s`` by ``serialize``).
        :param init_value: initial estimate; stored as ``pred_value`` and ``value``.
        :param parent: optional parent node; this node appends itself to
            ``parent.children``.
        :param is_known: if True, the node starts closed and already successful.
        :param zero_known_value: when ``is_known``, force ``value`` to 0.
        """
        self.mol = mol
        self.pred_value = init_value  # initial estimate; never updated here
        self.value = init_value
        self.succ_value = np.inf  # best solution cost below this node; inf = none
        self.parent = parent

        self.id = -1  # placeholder id used by serialize(); presumably set by the tree owner

        # Depth is copied from the parent, not incremented; presumably the
        # intermediate (reaction) level increments it -- confirm.
        if self.parent is None:
            self.depth = 0
        else:
            self.depth = self.parent.depth

        self.is_known = is_known
        self.children = []
        self.succ = is_known
        self.open = True  # True until init_values() expands the node
        if is_known:
            # Known molecules need no expansion and already count as solved.
            self.open = False
            if zero_known_value:
                self.value = 0
            self.succ_value = self.value

        if parent is not None:
            parent.children.append(self)

    def _min_over_children(self, key, initial=np.inf):
        """Fold ``np.min`` of ``initial`` and ``key(child)`` over all children."""
        best = initial
        for reaction in self.children:
            best = np.min((best, key(reaction)))
        return best

    def v_self(self):
        """
        :return: V_self(self | subtree), i.e. this node's current ``value``.
        """
        return self.value

    def v_target(self):
        """
        :return: V_target(self | whole tree); delegated to the parent when one
            exists, otherwise this node's own ``value``.
        """
        if self.parent is None:
            return self.value
        return self.parent.v_target()

    def init_values(self, no_child=False):
        """Close an open node after expansion and derive its values from children.

        Sets ``value`` to the minimum child ``v_self()``, ``succ`` to whether
        any child succeeded, and (if so) folds the children's ``succ_value``
        into ``succ_value``.

        :param no_child: allow closing the node with no children (``value``
            then becomes ``inf``).
        :return: change in ``value`` (new minus old).
        """
        assert self.open and (no_child or self.children)

        self.succ = False
        new_value = self._min_over_children(lambda reaction: reaction.v_self())
        for reaction in self.children:
            self.succ |= reaction.succ

        v_delta = new_value - self.value
        self.value = new_value

        if self.succ:
            # Fold in the current succ_value (inf for an unknown, unexpanded node).
            self.succ_value = self._min_over_children(
                lambda reaction: reaction.succ_value, initial=self.succ_value)

        self.open = False

        return v_delta

    def backup(self, succ):
        """Recompute values from children and propagate changes upward.

        :param succ: success flag reported from below; OR-ed into ``self.succ``.
        :return: the parent's ``backup`` result when anything changed and a
            parent exists, otherwise ``None``.
        """
        assert not self.is_known

        new_value = self._min_over_children(lambda reaction: reaction.v_self())
        new_succ = self.succ | succ
        updated = (self.value != new_value) or (self.succ != new_succ)

        new_succ_value = np.inf
        if new_succ:
            new_succ_value = self._min_over_children(
                lambda reaction: reaction.succ_value)
            updated = updated or (self.succ_value != new_succ_value)

        v_delta = new_value - self.value
        self.value = new_value
        self.succ = new_succ
        self.succ_value = new_succ_value

        # Explicit None check (as elsewhere in this class): a parent object that
        # happens to be falsy must still receive the update.
        if updated and self.parent is not None:
            return self.parent.backup(v_delta, from_mol=self.mol)
        return None

    def serialize(self):
        """
        :return: ``'<id> | <mol>'`` text representation of this node.
        """
        text = '%d | %s' % (self.id, self.mol)
        return text

    def get_ancestors(self):
        """Collect the ``mol`` of this node and of its molecule ancestors.

        Skips the direct parent and recurses via the grandparent.
        NOTE(review): this assumes parents alternate (molecule -> reaction ->
        molecule) and that a non-root node's parent always has a parent --
        confirm against the tree builder.

        :return: set of molecule identifiers.
        """
        if self.parent is None:
            return {self.mol}

        ancestors = self.parent.parent.get_ancestors()
        ancestors.add(self.mol)
        return ancestors