Spaces:
Running
Running
| # Copyright (c) Microsoft Corporation. | |
| # Licensed under the MIT license. | |
| import tensorflow as tf | |
# Module-level state for global_mutable_counting(); starts at 0 so the first
# call returns 1.
_counter = 0


def global_mutable_counting():
    """Bump the module-level ``_counter`` by one and return the new value.

    Successive calls return 1, 2, 3, ... for the lifetime of the module.
    The update is a plain read-modify-write on a global with no lock.
    """
    global _counter
    _counter = _counter + 1
    return _counter
class AverageMeter:
    """Track the most recent value and the running (weighted) mean of a metric.

    Parameters
    ----------
    name : str
        Label used by ``__str__`` and ``summary``.
    """

    def __init__(self, name):
        self.name = name
        self.val = 0    # most recent value passed to update()
        self.avg = 0    # running mean: sum / count
        self.sum = 0    # weighted sum of every value seen
        self.count = 0  # total weight seen (number of samples when n == 1)

    def update(self, val, n=1):
        """Record a new observation.

        Parameters
        ----------
        val : number
            Latest value of the metric.
        n : int, optional
            Weight of ``val``, e.g. the batch size it was averaged over.
            The default of 1 gives a plain running mean, as before.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # Guard against n == 0 on a fresh meter, which would divide by zero.
        if self.count:
            self.avg = self.sum / self.count

    def __str__(self):
        # '.4f' = 4 decimal places. The previous ':4f' set a field *width* of 4,
        # which never takes effect because '%f' output is always wider.
        return '{name} {val:.4f} ({avg:.4f})'.format(
            name=self.name, val=self.val, avg=self.avg)

    def summary(self):
        """Return ``'<name>: <avg>'`` with the average to 4 decimal places."""
        return '{name}: {avg:.4f}'.format(name=self.name, avg=self.avg)
class AverageMeterGroup:
    """A collection of AverageMeter objects keyed by metric name.

    A meter is created lazily the first time its key appears in ``update``.
    Meters are rendered in insertion order (dict order).
    """

    def __init__(self):
        self.meters = {}  # metric name -> AverageMeter

    def update(self, data):
        """Feed each ``name -> value`` pair of ``data`` into its meter."""
        for key, value in data.items():
            meter = self.meters.get(key)
            if meter is None:
                # First time this key is seen: start a fresh meter for it.
                meter = AverageMeter(key)
                self.meters[key] = meter
            meter.update(value)

    def __str__(self):
        return ' '.join(map(str, self.meters.values()))

    def summary(self):
        """Space-separated ``summary()`` of every meter."""
        parts = [meter.summary() for meter in self.meters.values()]
        return ' '.join(parts)
class StructuredMutableTreeNode:
    """A tree node wrapping one ``mutable`` (which may be ``None``).

    Each node owns an ordered list of child nodes. Iterating over a node walks
    its subtree in pre-order and yields each mutable once, identifying
    duplicates by their ``key`` attribute.
    """

    def __init__(self, mutable):
        self.mutable = mutable  # wrapped object, or None
        self.children = []      # child StructuredMutableTreeNode objects, in insertion order

    def add_child(self, mutable):
        """Wrap ``mutable`` in a new child node, append it, and return that node."""
        node = StructuredMutableTreeNode(mutable)
        self.children.append(node)
        return node

    def type(self):
        """Return ``type(self.mutable)`` (``NoneType`` when no mutable is wrapped)."""
        return type(self.mutable)

    def __iter__(self):
        # Default iteration: pre-order traversal with de-duplication by key.
        return self.traverse()

    def _own_mutable(self, deduplicate, memo):
        # Yield this node's own mutable (at most once), recording its key in memo.
        mutable = self.mutable
        if mutable is None:
            return
        if deduplicate and mutable.key in memo:
            return
        memo.add(mutable.key)
        yield mutable

    def traverse(self, order="pre", deduplicate=True, memo=None):
        """Yield the mutables in this subtree.

        Parameters
        ----------
        order : {"pre", "post"}
            Emit a node's own mutable before ("pre") or after ("post") those of
            its children. Any other value fails the ``assert`` on the first
            ``next()``. Under ``python -O`` the assert is skipped and nothing
            is yielded.
        deduplicate : bool
            If true, skip a mutable whose ``key`` is already in ``memo``.
        memo : set, optional
            Keys seen so far, shared across the recursive calls. A new set is
            created when omitted. Keys are added even when ``deduplicate`` is
            false.

        A node whose ``mutable`` is ``None``, or whose mutable is skipped as a
        duplicate, still has its children visited.
        """
        if memo is None:
            memo = set()
        assert order in ("pre", "post")
        if order == "pre":
            yield from self._own_mutable(deduplicate, memo)
        for child in self.children:
            yield from child.traverse(order=order, deduplicate=deduplicate, memo=memo)
        if order == "post":
            yield from self._own_mutable(deduplicate, memo)
def fill_zero_grads(grads, weights):
    """Return ``grads`` with every ``None`` entry replaced by zeros.

    Gradients and weights are paired positionally with ``zip``, so the result
    is as long as the shorter of the two sequences. A ``None`` gradient becomes
    ``tf.zeros_like`` of its paired weight; every other gradient is passed
    through unchanged.

    NOTE(review): the ``None`` entries presumably come from
    ``tf.GradientTape.gradient`` for variables not connected to the loss;
    confirm against the caller.
    """
    return [
        tf.zeros_like(w) if g is None else g
        for g, w in zip(grads, weights)
    ]