def apply(self, is_train, x, mask=None): _, d1, _, d2 = x.shape.as_list() w = tf.get_variable("w", (d1, d2, self.n_out), dtype=tf.float32) return tf.tensordot(x, w, [[1, 3], [0, 1]])