def _write(self, w, F_x, F_y, gamma):
    """Project a flat write patch back through the filterbanks and scale it.

    Computes ``transpose(F_y) @ reshape(w) @ F_x`` for every batch element,
    flattens the result to ``[-1, self.write_dim]`` and divides each row by
    the matching entry of ``gamma``.
    NOTE(review): from the names this looks like the "write" step of a
    DRAW-style attention model -- confirm against the enclosing class.

    Args:
        w: write patch; its elements must reshape to ``[batch, self.N, self.N]``.
        F_x: rank-3 filterbank whose second axis must be ``self.N`` (required
            by the inner matmul). Presumably ``[batch, N, A]`` -- TODO confirm.
        F_y: rank-3 filterbank, transposed on its last two axes before use, so
            its second axis must also be ``self.N``. Presumably
            ``[batch, N, B]`` -- TODO confirm.
        gamma: per-example scale; reshaped to ``[-1, 1]``, so it must hold
            exactly one value per batch row.

    Returns:
        Tensor of shape ``[batch, self.write_dim]``. ``write_dim`` presumably
        equals ``B * A`` so the reshape is valid -- confirm.
    """
    # `tf.batch_matmul` was removed in TensorFlow 1.0; from then on `tf.matmul`
    # handles batched (rank >= 3) operands. Use the old name when it exists so
    # pre-1.0 installs keep their original behavior unchanged.
    bmm = getattr(tf, "batch_matmul", tf.matmul)

    patch = tf.reshape(w, [-1, self.N, self.N])
    F_y_t = tf.transpose(F_y, [0, 2, 1])
    canvas = bmm(F_y_t, bmm(patch, F_x))

    # Multiply by the reciprocal (as the original did) rather than dividing,
    # so floating-point results stay identical.
    inv_gamma = tf.reshape(1. / gamma, [-1, 1])
    return tf.reshape(canvas, [-1, self.write_dim]) * inv_gamma
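# 文章目录 ("Table of contents") -- navigation text left over from the web page this snippet was copied from; not part of the code.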