def scalar_gating(
    net,
    activation=tf.nn.relu,
    k_initializer=tf.ones_initializer(),
    k_regularizer=None,
    k_regularizable=False,
):
    """Scale `net` elementwise by a single learned scalar gate.

    Creates (or reuses, depending on the ambient TF1 variable scope) a
    trainable variable `k` of shape (1,), and returns `activation(k) * net`.
    With the defaults (`k` initialized to 1, `activation=tf.nn.relu`) the
    gate starts as the identity: relu(1) * net == net.

    NOTE(review): this is TF1 graph-style code — `tf.get_variable('k', ...)`
    is namespaced by the enclosing `tf.variable_scope`; callers presumably
    wrap each call in a distinct scope to get one gate per call site —
    confirm against callers.

    Args:
        net: Tensor to gate. Broadcast against the (1,)-shaped `k`, so any
            shape works.
        activation: Function applied to `k` before multiplying (default
            `tf.nn.relu`, which clips negative gates to 0).
        k_initializer: Initializer for `k` (default all-ones → identity gate
            at init).
        k_regularizer: Optional regularizer passed to `tf.get_variable`.
        k_regularizable: Value stored on `k.regularizable`; defaults to
            False because, per the paper this implements, `k` may
            specifically be excluded from regularization.

    Returns:
        Tensor of the same shape as `net`: `activation(k) * net`.
    """
    # Represent this with shape (1,) instead of as a scalar to get proper
    # parameter count from tfprof.
    k = tf.get_variable(
        'k',
        (1,),
        initializer=k_initializer,
        regularizer=k_regularizer,
        trainable=True,
    )
    # Per the paper, we may specifically not want to regularize k.
    # `regularizable` is a custom attribute read by project-side training
    # code, not a TF built-in — TODO confirm the consumer.
    k.regularizable = k_regularizable
    return activation(k) * net
# (Removed non-code residue from a scraped web page: "评论列表" = "comment
# list", "文章目录" = "article table of contents" — page-navigation text,
# not part of this module.)