def deserialize(self, state):
    """Split a flat serialized state tensor into memory and head weights.

    ``state`` is sliced along its second axis, so it must be rank 2. The
    columns are consumed in this order:

    1. ``self.mem_nrows * self.mem_ncols`` values, reshaped into the
       memory matrix of the previous timestep.
    2. ``self.n_heads`` groups of ``self.mem_nrows`` values: the previous
       read weightings, one per read head.
    3. ``self.n_heads`` groups of ``self.mem_nrows`` values: the previous
       write weightings, one per write head.

    Args:
        state: Rank-2 tensor holding the serialized state. The first axis
            is kept whole (presumably the batch axis -- confirm against
            the caller).

    Returns:
        A tuple ``(M0, write_w0s, read_w0s)``: the memory tensor of shape
        ``[-1, mem_nrows, mem_ncols]``, the list of write weightings and
        the list of read weightings.
        NOTE(review): write weights are returned *before* read weights
        even though the read weights are stored first in ``state``; the
        order is kept for caller compatibility.

    Raises:
        tf.errors.InvalidArgumentError: At run time, if the second
            dimension of ``state`` is not exactly the number of columns
            consumed above.
    """
    mem_size = self.mem_nrows * self.mem_ncols

    # Memory matrix from the previous timestep.
    M0 = tf.slice(state, [0, 0], [-1, mem_size])
    M0 = tf.reshape(M0, [-1, self.mem_nrows, self.mem_ncols])
    state_idx = mem_size

    # Read weights from the previous timestep: one weight per memory row
    # for every read head.
    read_w0s = []
    for _ in range(self.n_heads):
        read_w0s.append(
            tf.slice(state, [0, state_idx], [-1, self.mem_nrows])
        )
        state_idx += self.mem_nrows

    # Same layout for the write heads.
    write_w0s = []
    for _ in range(self.n_heads):
        write_w0s.append(
            tf.slice(state, [0, state_idx], [-1, self.mem_nrows])
        )
        state_idx += self.mem_nrows

    # Every column of `state` must have been consumed. An Assert op that
    # nothing depends on is never executed in graph mode, so tie it to the
    # returned memory tensor to make the check effective.
    size_check = tf.Assert(
        tf.equal(state_idx, tf.shape(state)[1]),
        [tf.shape(state)],
    )
    with tf.control_dependencies([size_check]):
        M0 = tf.identity(M0)

    return M0, write_w0s, read_w0s
# [Web-page residue, not part of the code: "Comment list" / "Table of contents"]