def _cumprod(tensor, axis=0):
    """Cumulative product of `tensor` along `axis`, computed with `tf.scan`.

    A custom version of cumprod to prevent NaN gradients when there are zeros in
    `tensor`, as reported here: https://github.com/tensorflow/tensorflow/issues/3862

    :param tensor: tf.Tensor whose static rank is known (``get_shape()`` is used).
    :param axis: int, axis along which to accumulate; negative values count from
        the last dimension.
    :return: tf.Tensor with the same shape as `tensor`, holding running products
        along `axis`.
    :raises ValueError: if `axis` is out of range for the rank of `tensor`.
    """
    n_dim = len(tensor.get_shape())
    if axis < 0:
        axis += n_dim
    if not 0 <= axis < max(n_dim, 1):
        raise ValueError(
            "axis {} is out of range for a tensor of rank {}".format(axis, n_dim))

    # tf.scan always accumulates along dimension 0, so move `axis` to the front
    # by swapping it with dimension 0. A swap is its own inverse, so the same
    # permutation restores the original layout afterwards. When no swap is
    # needed, tf.transpose must not be called at all: with perm=None it
    # reverses every dimension instead of leaving the tensor unchanged.
    perm = None
    if n_dim > 1 and axis != 0:
        perm = list(range(n_dim))
        perm[0], perm[axis] = axis, 0
        tensor = tf.transpose(tensor, perm)

    def prod(acc, x):
        return acc * x

    # No initializer: the first slice along dim 0 seeds the accumulator.
    result = tf.scan(prod, tensor)

    if perm is not None:
        result = tf.transpose(result, perm)
    return result
# Comment list
# Table of contents