def __call__(self, inputs, state, scope=None):
    """Run one step of the memory grid (MemGrid) cell.

    Per step:
      1. GRU-style reset/update gates and a candidate are computed from
         ``inputs`` and the current ``self._memory``.
      2. ``att_weight`` produces per-line write weights over the memory.
      3. ``self._memory`` is reassigned to a blend of the old memory and the
         gated proposal, weighted by those write weights.
      4. The output is the write-weighted sum of the *updated* memory.

    Shapes: the write weights are reshaped to
    ``[self._batch_size, self._mem_size, 1]`` and broadcast against the
    memory, so the memory is presumably
    ``[self._batch_size, self._mem_size, self._mem_dim]`` (TODO confirm
    against the constructor).

    Args:
      inputs: input tensor for this step (layout assumed compatible with
        ``unbalance_linear`` / ``att_weight`` -- TODO confirm).
      state: not read; returned unchanged.
      scope: optional variable-scope name; defaults to the class name.

    Returns:
      ``(output, state)``, where ``output`` is reshaped to
      ``[self._batch_size, self._mem_dim]`` and ``state`` is the argument
      passed in.

    NOTE(review): the recurrent memory is carried on ``self._memory`` rather
    than in the returned state. This presumably only behaves as intended
    when the cell is unrolled statically (one Python call per time step),
    not inside a ``tf.while_loop``-based driver. Confirm against the
    callers.
    """
    with tf.variable_scope(scope or type(self).__name__):  # "MemGrid"
        # Pre-update memory. Every read below, up to the reassignment,
        # refers to this value.
        old_memory = self._memory

        with tf.variable_scope("Gates"):
            # Reset gate and update gate come from one projection.
            # The trailing 1.0 is presumably the bias start value: both
            # sigmoids then begin above 0.5, leaning toward "don't reset,
            # don't update".
            gate_logits = self.unbalance_linear(
                [inputs, old_memory], 2 * self._mem_dim, True, 1.0)
            reset_logits, update_logits = tf.split(gate_logits, 2, 2)
            reset_gate = sigmoid(reset_logits)
            update_gate = sigmoid(update_logits)

        with tf.variable_scope("Candidate"):
            candidate_logits = self.unbalance_linear(
                [inputs, reset_gate * old_memory], self._mem_dim, True)
            candidate = self._activation(candidate_logits)

        # Attention over the memory lines decides how strongly each line
        # is written this step.
        line_scores = att_weight(inputs,
                                 tf.concat([candidate, old_memory], 2),
                                 self.echocell,
                                 scope="Line_weights")
        write_weights = tf.reshape(
            line_scores, [self._batch_size, self._mem_size, 1])

        # GRU-style proposal, then a per-line interpolation between the
        # old memory and that proposal.
        proposal = update_gate * old_memory + (1 - update_gate) * candidate
        new_memory = old_memory * (1 - write_weights) + proposal * write_weights
        self._memory = new_memory

        # The readout reuses the write weights over the updated memory.
        summed = tf.reduce_sum(write_weights * new_memory, 1)
        output = tf.reshape(summed, [self._batch_size, self._mem_dim])
    return output, state
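# [Scraped web-page navigation labels, not part of the program:
#  "评论列表" = "Comment list", "文章目录" = "Table of contents"]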