diet.py source file

python

Project: tensor2tensor  Author: tensorflow
def diet_expert(x, hidden_size, params):
  """A two-layer feed-forward network with relu activation on hidden layer.

  Uses diet variables.
  Recomputes hidden layer on backprop to save activation memory.

  Args:
    x: a Tensor with shape [batch, io_size]
    hidden_size: an integer
    params: a diet variable HParams object.

  Returns:
    a Tensor with shape [batch, io_size]
  """

  @fn_with_diet_vars(params)
  def diet_expert_internal(x):
    dim = x.get_shape().as_list()[-1]
    h = tf.layers.dense(x, hidden_size, activation=tf.nn.relu, use_bias=False)
    y = tf.layers.dense(h, dim, use_bias=False)
    y *= tf.rsqrt(tf.to_float(dim * hidden_size))
    return y

  return diet_expert_internal(x)
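
To make the arithmetic inside `diet_expert_internal` concrete, here is a minimal pure-Python sketch of the same forward pass: a bias-free dense layer with ReLU, a second bias-free dense layer back to the input width, and a final scaling by `1 / sqrt(dim * hidden_size)`. It is only an illustration. The names `expert_forward`, `w_in`, and `w_out` are hypothetical, the weights are passed in explicitly rather than created by `tf.layers.dense`, and the diet-variable handling and backprop recomputation done by `fn_with_diet_vars` are not modeled at all.

```python
import math


def relu(values):
    """Element-wise max(0, v)."""
    return [max(0.0, v) for v in values]


def matvec(row, weights):
    """Multiply a row vector of length n by an n x m weight matrix."""
    n_out = len(weights[0])
    return [sum(r * w[j] for r, w in zip(row, weights)) for j in range(n_out)]


def expert_forward(x, w_in, w_out):
    """Forward pass of a two-layer feed-forward expert (illustrative only).

    x:     [batch][io_size]
    w_in:  [io_size][hidden_size]   (hypothetical explicit weights)
    w_out: [hidden_size][io_size]   (hypothetical explicit weights)
    Returns a list with shape [batch][io_size].
    """
    io_size = len(w_in)
    hidden_size = len(w_out)
    # Same scaling as y *= tf.rsqrt(tf.to_float(dim * hidden_size)).
    scale = 1.0 / math.sqrt(io_size * hidden_size)
    out = []
    for row in x:
        h = relu(matvec(row, w_in))   # hidden layer, no bias
        y = matvec(h, w_out)          # output layer, no bias
        out.append([scale * v for v in y])
    return out


if __name__ == "__main__":
    x = [[1.0, -2.0]]
    w_in = [[1.0, 0.0, 1.0, 0.0],
            [0.0, 1.0, 0.0, -1.0]]
    w_out = [[1.0, 0.0],
             [0.0, 1.0],
             [1.0, 1.0],
             [0.0, 2.0]]
    print(expert_forward(x, w_in, w_out))
```

The output has the same last dimension as the input, matching the `[batch, io_size]` shapes in the docstring.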