|
import copy |
|
from core.leras import nn |
|
tf = nn.tf |
|
|
|
class OptimizerBase(nn.Saveable): |
|
def __init__(self, name=None): |
|
super().__init__(name=name) |
|
|
|
def tf_clip_norm(self, g, c, n): |
|
"""Clip the gradient `g` if the L2 norm `n` exceeds `c`. |
|
# Arguments |
|
g: Tensor, the gradient tensor |
|
c: float >= 0. Gradients will be clipped |
|
when their L2 norm exceeds this value. |
|
n: Tensor, actual norm of `g`. |
|
# Returns |
|
Tensor, the gradient clipped if required. |
|
""" |
|
if c <= 0: |
|
return g |
|
|
|
condition = n >= c |
|
then_expression = tf.scalar_mul(c / n, g) |
|
else_expression = g |
|
|
|
|
|
if isinstance(then_expression, tf.Tensor): |
|
g_shape = copy.copy(then_expression.get_shape()) |
|
elif isinstance(then_expression, tf.IndexedSlices): |
|
g_shape = copy.copy(then_expression.dense_shape) |
|
if condition.dtype != tf.bool: |
|
condition = tf.cast(condition, 'bool') |
|
g = tf.cond(condition, |
|
lambda: then_expression, |
|
lambda: else_expression) |
|
if isinstance(then_expression, tf.Tensor): |
|
g.set_shape(g_shape) |
|
elif isinstance(then_expression, tf.IndexedSlices): |
|
g._dense_shape = g_shape |
|
|
|
return g |
|
nn.OptimizerBase = OptimizerBase |
|
|