layers.py source code

python

Project: rllabplusplus  Author: shaneshixiang
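def get_output_for(self, input, **kwargs):
        """Run the LSTM over `input` (batch, steps, ...) and return the hidden
        state at every step, shaped (batch, steps, num_units).

        The initial states self.h0 and self.c0 are tiled across the batch.
        With a fixed self.horizon the loop is unrolled in Python; otherwise
        tf.scan handles a number of steps known only at run time.
        """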
        input_shape = tf.shape(input)
        n_batches = input_shape[0]
        h0s = tf.tile(
            tf.reshape(self.h0, (1, self.num_units)),
            (n_batches, 1)
        )
        h0s.set_shape((None, self.num_units))
        c0s = tf.tile(
            tf.reshape(self.c0, (1, self.num_units)),
            (n_batches, 1)
        )
        c0s.set_shape((None, self.num_units))
        state = (c0s, h0s)
        if self.horizon is not None:
            outputs = []
            for idx in range(self.horizon):
                output, state = self.lstm(input[:, idx, :], state, scope=self.scope)
                outputs.append(tf.expand_dims(output, 1))
            outputs = tf.concat(axis=1, values=outputs)
            return outputs
        else:
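            n_steps = input_shape[1]
            # flatten any extra trailing dimensions into a single feature axis
            input = tf.reshape(input, tf.stack([n_batches, n_steps, -1]))
            # switch to time-major layout, since tf.scan iterates over axis 0
            shuffled_input = tf.transpose(input, (1, 0, 2))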
            shuffled_input.set_shape((None, None, self.input_shape[-1]))
            hcs = tf.scan(
                self.step,
                elems=shuffled_input,
                initializer=tf.concat(axis=1, values=[h0s, c0s]),
            )
            # back to batch-major, then split the packed [h, c] state
            shuffled_hcs = tf.transpose(hcs, (1, 0, 2))
            shuffled_hs = shuffled_hcs[:, :, :self.num_units]
            shuffled_cs = shuffled_hcs[:, :, self.num_units:]  # cell states, not returned
            return shuffled_hs
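
The variable-length branch above carries the hidden and cell states through tf.scan as one tensor, concatenated along the feature axis and split again afterwards. The pure-Python sketch below illustrates only that packing pattern, using nested lists in place of tensors. The names scan, make_toy_step and run are hypothetical, and toy_step is a made-up update rule standing in for self.step, not an LSTM cell.

```python
# Hypothetical sketch of the packed-state scan pattern (no TensorFlow needed).

def scan(step, elems, initializer):
    """Apply step along the leading axis, keeping every accumulator value."""
    acc = initializer
    history = []
    for x in elems:
        acc = step(acc, x)
        history.append(acc)
    return history


def make_toy_step(num_units):
    def toy_step(hc_prev, x):
        # hc_prev holds one packed row [h..., c...] per batch entry.
        rows = []
        for hc_row, x_row in zip(hc_prev, x):
            h_prev = hc_row[:num_units]   # first half: hidden state
            c_prev = hc_row[num_units:]   # second half: cell state
            c = [cp + sum(x_row) for cp in c_prev]              # toy cell update
            h = [0.5 * (hp + cv) for hp, cv in zip(h_prev, c)]  # toy hidden update
            rows.append(h + c)            # re-pack as [h, c]
        return rows
    return toy_step


def run(inputs, h0, c0):
    """inputs is batch-major [batch][time][features]; returns [batch][time][num_units]."""
    num_units = len(h0)
    n_batches = len(inputs)
    # Tile the single initial state across the batch, packed as [h0, c0].
    init = [list(h0) + list(c0) for _ in range(n_batches)]
    # Batch-major to time-major, because scan iterates over the leading axis.
    time_major = [list(rows) for rows in zip(*inputs)]
    hcs = scan(make_toy_step(num_units), time_major, init)
    # Back to batch-major, keeping only the hidden half of each packed row.
    return [[hcs[t][b][:num_units] for t in range(len(hcs))]
            for b in range(n_batches)]
```

Calling run([[[1.0], [2.0]], [[0.0], [0.0]]], [0.0], [0.0]) returns one list of per-step hidden states for each batch entry, in the same batch-major order as the input.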