def apply(self, is_train, x, mask=None):
    """Project axis 1 of ``x`` onto a learned weight vector and add a bias.

    Two variables are requested through ``tf.get_variable``: ``"w"``, a
    float32 vector with one entry per position along axis 1 of ``x``
    (every entry initialised to ``axis_size / 3.0``), and ``"b"``, a
    float32 scalar initialised to zero.

    Args:
        is_train: Not read by this method.
        x: Tensor whose axis-1 size is taken from its static shape.
            NOTE(review): presumably rank >= 2 with a statically known
            axis-1 size -- confirm against callers.
        mask: Not read by this method.

    Returns:
        ``x`` contracted over its axis 1 with ``w``, plus ``b``.
    """
    axis_size = x.shape.as_list()[1]
    weights = tf.get_variable(
        "w",
        (axis_size,),
        dtype=tf.float32,
        initializer=tf.constant_initializer(axis_size / 3.0),
    )
    bias = tf.get_variable(
        "b", (), dtype=tf.float32, initializer=tf.zeros_initializer()
    )
    # Contract axis 1 of x against the only axis of the weight vector.
    projected = tf.tensordot(x, weights, [[1], [0]])
    return projected + bias
评论列表
文章目录