T_LSTM_AE.py source code

python

Project: T-LSTM  Author: illidanlab
def get_decoder_states(self):
        batch_size = tf.shape(self.input)[0]
        seq_length = tf.shape(self.input)[1]
        # [batch_size x seq_length x input_dim] -> [seq_length x batch_size x input_dim]
        scan_input_ = tf.transpose(self.input, perm=[1, 0, 2])
        # Shift the inputs one step forward in time: append a zero step, then drop the first step.
        z = tf.zeros([1, batch_size, self.input_dim], dtype=tf.float32)
        scan_input = tf.concat([scan_input_, z], 0)
        scan_input = tf.slice(scan_input, [1, 0, 0], [seq_length, batch_size, self.input_dim])
        # Reverse along the time axis so the decoder sees the zero step first.
        scan_input = tf.reverse(scan_input, [0])
        scan_time_ = tf.transpose(self.time)  # scan_time [seq_length x batch_size]
        # Apply the same shift and reversal to the elapsed times so they stay aligned with the inputs.
        z2 = tf.zeros([1, batch_size], dtype=tf.float32)
        scan_time = tf.concat([scan_time_, z2], 0)
        scan_time = tf.slice(scan_time, [1, 0], [seq_length, batch_size])
        scan_time = tf.reverse(scan_time, [0])
        initial_hidden, initial_cell = self.get_representation()
        ini_state_cell = tf.stack([initial_hidden, initial_cell])
        # make scan_time [seq_length x batch_size x 1]
        scan_time = tf.reshape(scan_time, [tf.shape(scan_time)[0], tf.shape(scan_time)[1], 1])
        concat_input = tf.concat([scan_time, scan_input],2)  # [seq_length x batch_size x input_dim+1]
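        # tf.scan stacks the per-step states along a new leading time axis.
        packed_hidden_states = tf.scan(self.T_LSTM_Decoder_Unit, concat_input, initializer=ini_state_cell, name='decoder_states')
        # Index 0 on axis 1 selects the hidden state (same [hidden, cell] order as ini_state_cell).
        all_decoder_states = packed_hidden_states[:, 0, :, :]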
        return all_decoder_states
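
The concat/slice/reverse sequence in the method above can be hard to picture from the tensor calls alone. The following is a minimal pure-Python sketch of the same idea on a toy time-major list. The helper name `shift_and_reverse` and the toy values are illustrative assumptions and are not part of the T-LSTM project; the original operates on TensorFlow tensors.

```python
def shift_and_reverse(steps, pad):
    """Drop the first time step, append `pad` at the end, then reverse time.

    `steps` is a time-major list: steps[t] holds the batch at time t.
    (Hypothetical helper for illustration only.)
    """
    shifted = steps[1:] + [pad]   # same length as `steps`
    return shifted[::-1]          # the padding step now comes first


# Toy time-major sequence: 4 time steps, batch of 1, input_dim of 2.
x = [[[1.0, 1.0]], [[2.0, 2.0]], [[3.0, 3.0]], [[4.0, 4.0]]]
zero_step = [[0.0, 0.0]]

decoder_input = shift_and_reverse(x, zero_step)
print(decoder_input)
# [[[0.0, 0.0]], [[4.0, 4.0]], [[3.0, 3.0]], [[2.0, 2.0]]]
```

Under these assumptions the decoder is fed a zero step first and then the original steps from last to second, which mirrors what the method does to both `scan_input` and `scan_time`.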