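# NOTE(review): the page-status text "Spaces: Running Running" captured with
# this file is web-page residue, not Python; it is kept only as this comment.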
from torch.fx.experimental.graph_gradual_typechecker import Refine | |
from torch.fx.tensor_type import TensorType | |
from torch.fx.experimental.unification import Var, unify # type: ignore[attr-defined] | |
def infer_symbolic_types_single_pass(traced):
    """
    Run exactly one round of the symbolic inferencer on ``traced``.

    A ``Refine`` instance collects equality constraints for ``traced``.
    The constraints are solved with ``unify_eq``. The resulting unifier is
    then written back into the node types of ``traced.graph`` by
    ``substitute_all_types``.
    """
    refiner = Refine(traced)
    refiner.refine()
    solution = unify_eq(refiner.constraints)
    substitute_all_types(traced.graph, solution)
def infer_symbolic_types(traced):
    """
    Run the symbolic inferencer on ``traced`` twice, then compute symbolic relations.

    One pass is not always enough to infer all the information, for example
    when broadcasting is involved. The refine / unify / substitute sequence
    is therefore repeated. The second pass runs on the graph whose node
    types were rewritten by the first.

    After the last pass, ``symbolic_relations()`` is called on the
    ``Refine`` instance created by that pass.
    """
    for _ in range(2):
        # A fresh Refine per pass, so each pass starts from the node types
        # that the previous substitute_all_types call wrote into the graph.
        r = Refine(traced)
        r.refine()
        mgu = unify_eq(r.constraints)
        substitute_all_types(traced.graph, mgu)
    # ``r`` is the Refine instance from the final pass.
    r.symbolic_relations()
def convert_eq(list_of_eq):
    """
    Split equality constraints into the two tuples handed to the unifier.

    Args:
        list_of_eq: iterable of constraint objects exposing ``lhs`` and ``rhs``.

    Returns:
        A pair ``(lhs, rhs)`` of tuples. Position ``i`` in each tuple holds
        the corresponding side of the ``i``-th constraint.
    """
    pairs = [(eq.lhs, eq.rhs) for eq in list_of_eq]
    left = tuple(pair[0] for pair in pairs)
    right = tuple(pair[1] for pair in pairs)
    return left, right
def unify_eq(list_of_eq):
    """
    Solve a collection of equality constraints by unification.

    ``convert_eq`` first splits the constraints into matching ``lhs`` and
    ``rhs`` tuples. The two tuples are then unified as a whole, so every
    constraint must hold at the same time.

    Returns:
        The result of ``unify`` on the two tuples. Callers in this module use
        it as the most general unifier, iterating it and indexing it by key.
    """
    # NOTE(review): the unification library appears to return ``False``
    # rather than raise when the constraints are unsatisfiable. Confirm
    # this; if so, callers that iterate the result would then fail.
    left_sides, right_sides = convert_eq(list_of_eq)
    solution = unify(left_sides, right_sides)
    return solution
def substitute_solution_one_type(mapping, t):
    """
    Apply the most general unifier ``mapping`` to a single type ``t``.

    * A ``Var`` is replaced by its binding, if it has one.
    * A ``TensorType`` has each entry of its ``__args__`` replaced by its
      binding. Entries are looked up directly and are not recursed into.
    * Lists and tuples are rebuilt element by element, recursively, keeping
      their container type.
    * Anything else is returned unchanged.

    Values without a binding in ``mapping`` are left as they are.
    """
    def lookup(key):
        # Binding for ``key`` if the unifier has one, otherwise ``key`` itself.
        return mapping[key] if key in mapping.keys() else key

    if isinstance(t, Var):
        return lookup(t)
    if isinstance(t, TensorType):
        return TensorType(tuple(lookup(arg) for arg in t.__args__))
    if isinstance(t, list):
        return [substitute_solution_one_type(mapping, item) for item in t]
    if isinstance(t, tuple):
        return tuple(substitute_solution_one_type(mapping, item) for item in t)
    return t
def substitute_all_types(graph, mapping):
    """
    Apply the most general unifier to the type of every node in a graph.

    Chains inside ``mapping`` are collapsed in place before substituting.
    Whenever a value is itself a key of ``mapping``, it is replaced by that
    key's value. Full passes repeat until one makes no change. This is a
    fixed point on the mapping, not a comparison of graphs. Each node's
    ``type`` is then rewritten with ``substitute_solution_one_type``.

    Args:
        graph: object whose ``nodes`` are iterated; each node's ``type``
            attribute is reassigned.
        mapping: the unifier (as produced by ``unify_eq`` in this module);
            it is mutated in place.

    Returns:
        None. Both ``mapping`` and the node types are updated in place.
    """
    changed = True
    while changed:
        changed = False
        for key in mapping:
            value = mapping[key]
            # The value is itself bound: follow one link of the chain. Only
            # values are reassigned, so iterating the dict stays safe.
            if value in mapping:
                mapping[key] = mapping[value]
            if value != mapping[key]:
                changed = True

    for n in graph.nodes:
        n.type = substitute_solution_one_type(mapping, n.type)
def check_for_type_equality(g1, g2):
    """
    Check whether two graphs carry the same node types, position by position.

    Intended for fixed-point checks. It compares types only, not full graph
    equality.

    Args:
        g1, g2: objects whose ``nodes`` iterate nodes with a ``type`` attribute.

    Returns:
        True if both graphs have the same number of nodes and every pair of
        nodes at the same position has equal types, otherwise False.
    """
    nodes1 = list(g1.nodes)
    nodes2 = list(g2.nodes)
    # Graphs with different node counts can never be type-equal. A plain
    # zip() would silently ignore the extra nodes of the longer graph.
    if len(nodes1) != len(nodes2):
        return False
    for n, m in zip(nodes1, nodes2):
        if n.type != m.type:
            return False
    return True