def apply(self, is_train, x, mask=None):
    """Run the wrapped layer and add a linear projection of the input to it.

    The wrapped layer ``self.other`` is applied first; ``is_train`` and
    ``mask`` are only forwarded to it and are not otherwise used here.
    The input ``x`` is then projected along its last axis with a weight
    matrix named ``"project_w"`` so that its last dimension matches the
    wrapped layer's output, and the two are summed.

    Args:
        is_train: Forwarded unchanged to ``self.other.apply``.
        x: Input tensor. Its last dimension must be statically known,
            since it is used to size the projection matrix.
        mask: Optional value forwarded unchanged to ``self.other.apply``.

    Returns:
        The wrapped layer's output plus the projection of ``x``.
    """
    # Build the wrapped layer first so that any graph state it creates is
    # created before the projection weights.
    wrapped_out = self.other.apply(is_train, x, mask)

    # Projection matrix shape: (last dim of input, last dim of wrapped output).
    in_dim = x.shape.as_list()[-1]
    out_dim = wrapped_out.shape.as_list()[-1]
    projection = tf.get_variable("project_w", (in_dim, out_dim))

    # Contract the last axis of x with the first axis of the weights.
    last_axis = len(x.shape) - 1
    projected = tf.tensordot(x, projection, axes=[[last_axis], [0]])
    return wrapped_out + projected
评论列表
文章目录