def diet_expert(x, hidden_size, params):
  """A two-layer feed-forward network with ReLU activation on the hidden layer.

  Uses diet variables. Recomputes the hidden layer on the backward pass to
  save activation memory.

  Args:
    x: a Tensor with shape [batch, io_size]
    hidden_size: an integer
    params: a diet variable HParams object.

  Returns:
    a Tensor with shape [batch, io_size]
  """

  # The decorator makes the inner function's weights diet variables and
  # recomputes its forward pass during backprop, as described in the docstring.
  @fn_with_diet_vars(params)
  def diet_expert_internal(x):
    dim = x.get_shape().as_list()[-1]
    h = tf.layers.dense(
        x, hidden_size, activation=tf.nn.relu, use_bias=False)
    y = tf.layers.dense(h, dim, use_bias=False)
    # Scale the output down by 1/sqrt(dim * hidden_size) to keep it well-scaled.
    y *= tf.rsqrt(tf.to_float(dim * hidden_size))
    return y

  return diet_expert_internal(x)
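A minimal usage sketch follows. It assumes this function lives in tensor2tensor.layers.diet and that diet_adam_optimizer_params() is the helper that builds the default diet-variable HParams; both names come from tensor2tensor and are assumptions if this excerpt is from a different codebase. It also assumes TF1-style graph mode, matching the tf.layers API used above.

# Sketch only: assumes tensor2tensor.layers.diet provides diet_expert and
# diet_adam_optimizer_params (assumed helper for default diet HParams).
import tensorflow as tf
from tensor2tensor.layers import diet

params = diet.diet_adam_optimizer_params()
x = tf.random_uniform([8, 32])                            # [batch, io_size]
y = diet.diet_expert(x, hidden_size=64, params=params)    # -> shape [8, 32]

# Diet variables apply their own updates inside the custom backward pass,
# so a "training step" here just evaluates the gradients.
(grad_x,) = tf.gradients(tf.reduce_sum(y), [x])
with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  sess.run(grad_x)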