def step(self, input_t, states):
    """Run one timestep: read the input, consult both memories, write back.

    The reader's step function processes ``input_t``. Its output ``o_t`` is
    used to summarize the private memory and the shared memory. The output
    and both summaries are composed and fed to the writer's step function,
    and the writer's hidden state ``h_t`` is written back into both memories.

    Args:
        input_t: Input for the current timestep.
        states: Exactly six tensors, in this order:
            ``[reader_h_tm1, reader_c_tm1, flattened_mem_tm1,
            flattened_shared_mem_tm1, writer_h_tm1, writer_c_tm1]``.
            Both memories arrive flattened to 2-D, i.e.
            ``(batch, mem_len * output_dim)``.

    Returns:
        ``(h_t, new_states)``. ``h_t`` is the writer's hidden state
        (``(batch_size, output_dim)``). ``new_states`` has the same
        six-element layout as ``states``, with both memories flattened again.
    """
    # list(...) so the concatenations below also work if Keras hands us
    # tuples instead of lists (list + tuple raises TypeError).
    reader_states = list(states[:2])
    flattened_mem_tm1, flattened_shared_mem_tm1 = states[2:4]
    writer_h_tm1, writer_c_tm1 = states[4:]

    # Memories travel through the recurrent loop flattened; restore them to
    # (batch, mem_len, output_dim). Floor division is required: on Python 3
    # `/` yields a float (or a float tensor for symbolic shapes), which
    # reshape does not accept as a dimension.
    input_mem_shape = K.shape(flattened_mem_tm1)
    mem_shape = (
        input_mem_shape[0],
        input_mem_shape[1] // self.output_dim,
        self.output_dim,
    )
    mem_tm1 = K.reshape(flattened_mem_tm1, mem_shape)
    shared_mem_tm1 = K.reshape(flattened_shared_mem_tm1, mem_shape)

    # Read: call the reader's step directly, appending its per-step constants
    # to its recurrent states as its step function expects.
    reader_constants = self.reader.get_constants(input_t)
    reader_states = reader_states + list(reader_constants)
    o_t, [_, reader_c_t] = self.reader.step(input_t, reader_states)

    # Summarize each memory against the reader output. z_t / shared_z_t are
    # reused below by update_memory (presumably read weights over memory
    # slots -- confirm against summarize_memory).
    z_t, m_rt = self.summarize_memory(o_t, mem_tm1)
    shared_z_t, shared_m_rt = self.summarize_memory(o_t, shared_mem_tm1)
    c_t = self.compose_memory_and_output([o_t, m_rt, shared_m_rt])

    # Collecting the necessary variables to directly call writer's step function.
    writer_constants = self.writer.get_constants(c_t)  # returns dropouts for W and U (all 1s, see init)
    writer_states = [writer_h_tm1, writer_c_tm1] + list(writer_constants)
    # Making a call to writer's step function, Equation 5
    h_t, [_, writer_c_t] = self.writer.step(c_t, writer_states)  # h_t, writer_c_t: (batch_size, output_dim)

    # Write: update both memories with the writer's hidden state.
    mem_t = self.update_memory(z_t, h_t, mem_tm1)
    shared_mem_t = self.update_memory(shared_z_t, h_t, shared_mem_tm1)

    # Re-flatten the memories so they fit the 2-D state layout of the loop.
    return h_t, [
        o_t,
        reader_c_t,
        K.batch_flatten(mem_t),
        K.batch_flatten(shared_mem_t),
        h_t,
        writer_c_t,
    ]
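# NOTE: two stray navigation labels from the web page this code was copied
# from originally appeared here: "评论列表" ("Comment list") and "文章目录"
# ("Table of contents"). They are not part of the program.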