diet.py source code

Project: tefla  Author: openAGI
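import tensorflow as tf  # TF 1.x API: tf.layers.dense, tf.to_float, tf.rsqrt

# fn_with_diet_vars is assumed to be defined elsewhere in the same diet.py module.


def diet_expert(x, hidden_size, params):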
    """A two-layer feed-forward network with relu activation on hidden layer.

    Uses diet variables.
    Recomputes the hidden layer during backprop to save activation memory.

    Args:
      x: a Tensor with shape [batch, io_size]
      hidden_size: an integer
      params: a diet variable HParams object.

    Returns:
      a Tensor with shape [batch, io_size]
    """

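    # Build the layers inside a function wrapped by fn_with_diet_vars, so the
    # variables they create are diet variables configured by params.
    @fn_with_diet_vars(params)
    def diet_expert_internal(x):
        # io_size: the last (feature) dimension of the input
        dim = x.get_shape().as_list()[-1]
        # Hidden layer: relu(x @ W1), no bias
        h = tf.layers.dense(
            x, hidden_size, activation=tf.nn.relu, use_bias=False)
        # Output layer projects back to io_size: h @ W2, no bias
        y = tf.layers.dense(h, dim, use_bias=False)
        # Scale the output by 1 / sqrt(dim * hidden_size)
        y *= tf.rsqrt(tf.to_float(dim * hidden_size))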
        return y

    return diet_expert_internal(x)
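Written out as a formula, `diet_expert_internal` computes the following. Here B is the batch size, d is `dim` (the io_size read from the last axis of `x`), h is `hidden_size`, and W_1, W_2 are labels introduced only for this illustration, standing for the kernels of the two `tf.layers.dense` calls. Both calls set `use_bias=False`, so there are no bias terms:

```latex
y = \frac{\operatorname{ReLU}(x W_1)\, W_2}{\sqrt{d \cdot h}},
\qquad x \in \mathbb{R}^{B \times d},\;
W_1 \in \mathbb{R}^{d \times h},\;
W_2 \in \mathbb{R}^{h \times d},\;
y \in \mathbb{R}^{B \times d}
```

The output shape [B, d] matches the `[batch, io_size]` promised in the docstring.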
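The docstring's note that the hidden layer is recomputed on backprop describes activation recomputation (often called gradient checkpointing). The following is a hypothetical pure-Python sketch of that idea applied to the same two-layer math; it does not use TensorFlow, tefla, or diet variables, and the names `RecomputingExpert`, `matmul`, `transpose`, and `random_matrix` are invented for this example. `forward()` caches only `x`, and `backward()` rebuilds `h = relu(x @ w1)` before computing gradients.

```python
import math
import random


def matmul(a, b):
    """Multiply an (n x k) matrix by a (k x m) matrix; both are lists of rows."""
    return [[sum(a[i][p] * b[p][j] for p in range(len(b)))
             for j in range(len(b[0]))]
            for i in range(len(a))]


def transpose(a):
    return [list(col) for col in zip(*a)]


def random_matrix(rows, cols, rng):
    return [[rng.uniform(-1.0, 1.0) for _ in range(cols)] for _ in range(rows)]


class RecomputingExpert:
    """Hypothetical stand-in for diet_expert's math (not tefla code).

    Computes y = relu(x @ w1) @ w2 / sqrt(dim * hidden_size) but keeps only x
    after forward(); the hidden activation h is rebuilt inside backward().
    """

    def __init__(self, w1, w2):
        self.w1 = w1  # [dim, hidden_size]
        self.w2 = w2  # [hidden_size, dim]
        self.scale = 1.0 / math.sqrt(len(w1) * len(w1[0]))
        self.saved_x = None

    def _hidden(self, x):
        # h = relu(x @ w1), shape [batch, hidden_size]
        return [[max(0.0, v) for v in row] for row in matmul(x, self.w1)]

    def forward(self, x):
        self.saved_x = x  # only the input is cached; h is discarded
        y = matmul(self._hidden(x), self.w2)
        return [[v * self.scale for v in row] for row in y]

    def backward(self, dy):
        """Return (dx, dw1, dw2) for the upstream gradient dy = dL/dy."""
        x = self.saved_x
        h = self._hidden(x)  # recompute: extra FLOPs instead of stored activations
        dz = [[g * self.scale for g in row] for row in dy]  # through the rescale
        dw2 = matmul(transpose(h), dz)
        dh = matmul(dz, transpose(self.w2))
        # ReLU passes gradient only where h > 0 (i.e. pre-activation > 0).
        dpre = [[g if hv > 0.0 else 0.0 for g, hv in zip(grow, hrow)]
                for grow, hrow in zip(dh, h)]
        dw1 = matmul(transpose(x), dpre)
        dx = matmul(dpre, transpose(self.w1))
        return dx, dw1, dw2


if __name__ == "__main__":
    rng = random.Random(0)
    dim, hidden_size, batch = 3, 4, 2
    expert = RecomputingExpert(random_matrix(dim, hidden_size, rng),
                               random_matrix(hidden_size, dim, rng))
    y = expert.forward(random_matrix(batch, dim, rng))
    dx, dw1, dw2 = expert.backward([[1.0] * dim for _ in range(batch)])
    print(len(y), len(y[0]))  # prints "2 3": output keeps [batch, io_size]
```

The trade-off is one extra hidden-layer matmul per backward pass in exchange for not holding a [batch, hidden_size] activation in memory between the forward and backward passes.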