def write_attention(self, hidden_layer):
    with tf.variable_scope("writeW", reuse=self.share_parameters):
        # Decode the hidden state into an attention_n x attention_n patch per color channel.
        w = dense(hidden_layer, self.n_hidden, self.attention_n * self.attention_n * self.num_colors)
    w = tf.reshape(w, [self.batch_size, self.attention_n, self.attention_n, self.num_colors])
    w_t = tf.transpose(w, perm=[3, 0, 1, 2])  # [num_colors, batch_size, attention_n, attention_n]
    Fx, Fy, gamma = self.attn_window("write", hidden_layer)
    # Stack the color channels along the batch axis:
    # color1, color2, color3, color1, color2, color3, etc.
    w_array = tf.reshape(w_t, [self.num_colors * self.batch_size, self.attention_n, self.attention_n])
    # Replicate the filterbanks once per color channel (assumes num_colors == 3).
    Fx_array = tf.concat(0, [Fx, Fx, Fx])
    Fy_array = tf.concat(0, [Fy, Fy, Fy])
    Fyt = tf.transpose(Fy_array, perm=[0, 2, 1])
    # Project each patch onto the full canvas:
    # [img_size, attn_n] * [attn_n, attn_n] * [attn_n, img_size]
    wr = tf.batch_matmul(Fyt, tf.batch_matmul(w_array, Fx_array))
    sep_colors = tf.reshape(wr, [self.batch_size, self.num_colors, self.img_size ** 2])
    # Reorder back to [batch_size, img_size, img_size, num_colors] and flatten.
    wr = tf.reshape(wr, [self.num_colors, self.batch_size, self.img_size, self.img_size])
    wr = tf.transpose(wr, [1, 2, 3, 0])
    wr = tf.reshape(wr, [self.batch_size, self.img_size * self.img_size * self.num_colors])
    # The write is scaled by 1/gamma (the read uses gamma).
    return wr * tf.reshape(1.0 / gamma, [-1, 1])
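
For intuition on the batched matmul at the heart of this write step, here is a minimal standalone NumPy sketch (not part of the model; the toy sizes attention_n = 5 and img_size = 28 are assumed for illustration). It projects a single write patch w onto the full canvas as Fy^T · w · Fx for one image and one color channel:

import numpy as np

# Assumed toy dimensions, for illustration only.
attention_n, img_size = 5, 28

# A 5x5 write patch and filterbanks Fx, Fy of shape [attention_n, img_size];
# in the model these come from the decoder output and attn_window().
w = np.random.randn(attention_n, attention_n)
Fx = np.random.rand(attention_n, img_size)
Fy = np.random.rand(attention_n, img_size)

# [img_size, attn_n] @ [attn_n, attn_n] @ [attn_n, img_size] -> [img_size, img_size]
canvas_patch = Fy.T @ w @ Fx
print(canvas_patch.shape)  # (28, 28)

In the TensorFlow code above, this same product is computed for every batch entry and color channel at once, which is why the patches and filterbanks are first stacked along the leading dimension before calling tf.batch_matmul.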