def call(self, x, mask=None):
    """Run the recurrence over ``x`` and return the result selected by ``self.return_mode``.

    The recurrence is seeded with this layer's initial states followed by the
    writer's initial states, then driven by ``self.loop``.

    Args:
        x: Input sequence. Per the author's note it is presumably
            ``(batch_size, input_length, input_dim)``, a shape that must be
            established in ``build`` -- TODO confirm.
        mask: Optional mask, forwarded unchanged to
            ``self.get_initial_states`` and ``self.loop``.

    Returns:
        Depends on ``self.return_mode`` (shapes as noted by the author):

        * ``"last_output"``: the final output, ``(batch_size, output_dim)``.
        * ``"all_outputs"``: every step's output,
          ``(batch_size, input_length, output_dim)``.
        * any other value (intended to be ``"output_and_memory"``): the final
          output prepended along axis 1 to the final memory state,
          ``(batch_size, 1 + input_length, output_dim)``.
    """
    read_init = self.get_initial_states(x, mask)

    # One-step "fake" input built from the first reader state,
    # (batch_size, 1, output_dim). It is only handed to the writer's
    # get_initial_states (presumably just to shape h_0 / c_0 -- confirm).
    writer_seed = K.expand_dims(read_init[0], dim=1)
    write_init = self.writer.get_initial_states(writer_seed)  # writer LSTM's h_0 and c_0

    seed_states = read_init + write_init
    final_output, every_output, final_states = self.loop(x, seed_states, mask)

    # Per the author's notes, final_states holds, in order: the memory state
    # (batch_size, input_length, output_dim), the last output, and the
    # writer's last c_t. Only the memory is needed here.
    memory = final_states[0]

    if self.return_mode == "last_output":
        return final_output
    if self.return_mode == "all_outputs":
        return every_output

    # Fallthrough (intended "output_and_memory"): put the last output in
    # front of the memory as one extra timestep along axis 1.
    # NOTE: `dim=` is the Keras 1.x backend keyword; Keras 2 renamed it `axis`.
    output_step = K.expand_dims(final_output, dim=1)
    return K.concatenate([output_step, memory], axis=1)
评论列表
文章目录