def get_initial_states(self, nse_input, input_mask=None):
    '''
    Split the MMA-NSE input and build the initial recurrent states.

    Read input in MMA-NSE will be of shape (batch_size, read_input_length*2, input_dim), a
    concatenation of the actual input to this NSE and the output from a different NSE. The
    latter is used to initialize the shared memory. The former is passed to the read LSTM
    and also used to initialize the current memory.

    Parameters
    ----------
    nse_input : Tensor
        Concatenated input of shape (batch_size, read_input_length*2, input_dim).
    input_mask : Tensor or None
        Mask for ``nse_input``. It is assumed (TODO confirm against callers) to be
        (batch_size, time) with its time axis aligned to ``nse_input``. Only the first
        ``read_input_length`` steps, which cover the read input, are forwarded to the
        reader.

    Returns
    -------
    tuple(list, mask)
        ``initial_states``: the list returned by ``self.reader.get_initial_states``,
        followed by the flattened current memory ``mem_0`` (first half of the time
        axis) and the flattened initial shared memory (second half).
        ``o_mask``: whatever ``self.reader.compute_mask`` returns for the read input.
    '''
    input_length = K.shape(nse_input)[1]
    # Floor division keeps the split point an integer (Python int or integer tensor).
    # Plain "/" is true division under Python 3 and produces a float, which cannot be
    # used as a slice index.
    read_input_length = input_length // 2
    input_to_read = nse_input[:, :read_input_length, :]
    initial_shared_memory = K.batch_flatten(nse_input[:, read_input_length:, :])
    mem_0 = K.batch_flatten(input_to_read)
    # The reader only sees the first half of the time axis, so its mask must cover the
    # same steps. If the mask is already read-length, this slice is a no-op.
    # NOTE(review): assumes a 2-D (batch_size, time) mask -- confirm against callers.
    read_mask = None if input_mask is None else input_mask[:, :read_input_length]
    o_mask = self.reader.compute_mask(input_to_read, read_mask)
    reader_states = self.reader.get_initial_states(nse_input)
    initial_states = reader_states + [mem_0, initial_shared_memory]
    return initial_states, o_mask
评论列表
文章目录