core.py source code

python

Project: thinstack-rl · Author: hans
import tensorflow as tf
from tensorflow.python.framework import ops


def _thin_stack_update_gradient(op, stack_grad, *rest):
    # Recover what we need from the forward op: the stack tensor itself,
    # the batch size (leading dimension of a batch-sized input), and the
    # timestep at which this update ran, stored as an op attribute.
    stack = op.inputs[2]
    batch_size = op.inputs[4].get_shape().as_list()[0]
    t = op.get_attr("timestep")

    # We usually slice off the head of the stack output in feedforward and
    # send it off to downstream computation. The Slice feedforward op will
    # generate a sparse gradient in the backward pass. Nix this sparsity
    # at the very start.
    if isinstance(stack_grad, ops.IndexedSlices):
        # Trick: re-use our stack structure to store new gradients.
        # Recover the original stack variable from the lookup/update chain.
        stack = _fetch_stack(stack)  # helper defined elsewhere in this module

        # Densify: zero out the full stack, then scatter the sparse gradient
        # rows back in, leaving a dense tensor of the stack's shape.
        stack = tf.assign(stack, tf.zeros_like(stack))
        stack = tf.scatter_update(stack, stack_grad.indices, stack_grad.values)
        stack_grad = stack

    with tf.control_dependencies([stack_grad]):
        # Slice out the rows written at timestep t: the gradient w.r.t. the
        # value pushed onto the stack in the forward pass.
        input_grad = tf.slice(stack_grad, [t * batch_size, 0], [batch_size, -1])

    # Gradients align positionally with op.inputs: the sliced gradient flows
    # to the first input, the (possibly densified) stack gradient to the
    # stack input; the remaining inputs receive no gradient.
    return input_grad, None, stack_grad, None, None, None
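
The isinstance branch above densifies a sparse IndexedSlices gradient by reusing the stack's own storage: zero the variable, scatter the sparse rows back in, and treat the result as a dense gradient. The following is a minimal standalone sketch of that trick using the same TF 1.x-era APIs as the snippet; the 4x3 shape and the variable names are illustrative, not taken from thinstack-rl:

import tensorflow as tf

# Stand-in for the stack storage (4 rows, feature size 3).
stack_var = tf.Variable(tf.zeros([4, 3]))

# A sparse gradient touching only rows 0 and 2.
sparse_grad = tf.IndexedSlices(values=tf.ones([2, 3]),
                               indices=tf.constant([0, 2]))

# Densify: zero everything, then scatter the touched rows back in.
dense = tf.assign(stack_var, tf.zeros_like(stack_var))
dense = tf.scatter_update(dense, sparse_grad.indices, sparse_grad.values)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(dense))  # rows 0 and 2 are ones, the rest zeros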
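
A gradient function with this (op, grad, ...) signature is normally attached to its forward op through TensorFlow's RegisterGradient mechanism. A sketch of how that hookup typically looks is below; the op name "ThinStackUpdate" is an assumption, since the snippet does not show how the forward op is actually registered in thinstack-rl:

from tensorflow.python.framework import ops

# Hypothetical op name; the real registration in thinstack-rl may differ.
ops.RegisterGradient("ThinStackUpdate")(_thin_stack_update_gradient)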