def get_decoder_states(self):
    """Run the T-LSTM decoder over the shifted, time-reversed input sequence.

    With inputs ``x_0 .. x_{T-1}`` and elapsed times ``t_0 .. t_{T-1}``, the
    sequence scanned by the decoder is ``[0, x_{T-1}, x_{T-2}, ..., x_1]``
    (likewise ``[0, t_{T-1}, ..., t_1]`` for the times), where ``0`` is an
    all-zero step. The sequence length is unchanged.

    Expects ``self.input`` shaped ``[batch_size, seq_length, self.input_dim]``
    and ``self.time`` shaped ``[batch_size, seq_length]``, both of the same
    dtype. The scan starts from the ``(hidden, cell)`` pair returned by
    ``self.get_representation()`` and steps with ``self.T_LSTM_Decoder_Unit``.

    Returns:
        The decoder hidden state at every step, stacked along axis 0 as
        ``[seq_length, ...]``. Presumably ``[seq_length, batch_size,
        hidden_dim]`` -- TODO confirm against ``get_representation``.
    """
    batch_size = tf.shape(self.input)[0]

    def _shift_and_reverse(steps, zero_step):
        # Append one all-zero step, drop step 0 (length is unchanged), then
        # reverse the time axis: [s_0, ..., s_{T-1}] -> [0, s_{T-1}, ..., s_1].
        # Concatenating before slicing keeps a zero-length sequence empty.
        return tf.reverse(tf.concat([steps, zero_step], 0)[1:], [0])

    # [batch_size, seq_length, input_dim] -> [seq_length, batch_size, input_dim]
    scan_input = tf.transpose(self.input, perm=[1, 0, 2])
    # Padding uses the input's own dtype instead of a hard-coded float32.
    input_pad = tf.zeros([1, batch_size, self.input_dim], dtype=self.input.dtype)
    scan_input = _shift_and_reverse(scan_input, input_pad)

    # [batch_size, seq_length] -> [seq_length, batch_size]
    scan_time = tf.transpose(self.time, perm=[1, 0])
    time_pad = tf.zeros([1, batch_size], dtype=self.time.dtype)
    scan_time = _shift_and_reverse(scan_time, time_pad)

    initial_hidden, initial_cell = self.get_representation()
    # The scan state packs hidden and cell along a new leading axis of size 2.
    ini_state_cell = tf.stack([initial_hidden, initial_cell])

    # [seq_length, batch_size, 1 + input_dim]: each scan element carries the
    # elapsed time in column 0, followed by the input features.
    concat_input = tf.concat([tf.expand_dims(scan_time, 2), scan_input], 2)

    packed_hidden_states = tf.scan(
        self.T_LSTM_Decoder_Unit,
        concat_input,
        initializer=ini_state_cell,
        name='decoder_states',
    )
    # Keep only the hidden half (index 0) of each stacked (hidden, cell) state.
    all_decoder_states = packed_hidden_states[:, 0, :, :]
    return all_decoder_states
评论列表
文章目录