layers.py source code

python

Project: gail-driver  Author: sisl
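```python
import tensorflow as tf


def get_output_for(self, input, **kwargs):
    # Tile the learned initial hidden state h0 across the batch:
    # (num_units,) -> (n_batches, num_units).
    input_shape = tf.shape(input)
    n_batches = input_shape[0]
    state = tf.tile(
        tf.reshape(self.h0, (1, self.num_units)),
        (n_batches, 1)
    )
    state.set_shape((None, self.num_units))
    if self.horizon is not None:
        # Fixed horizon: unroll self.gru one time step at a time in Python.
        outputs = []
        for idx in range(self.horizon):
            output, state = self.gru(
                input[:, idx, :], state, scope=self.scope)  # self.name)
            outputs.append(tf.expand_dims(output, 1))
        # Pre-1.0 TensorFlow argument order (axis first);
        # in TF >= 1.0 this is tf.concat(outputs, axis=1).
        outputs = tf.concat(1, outputs)
        return outputs
    else:
        # Variable-length sequences: run self.step over time with tf.scan.
        n_steps = input_shape[1]
        # Flatten extra dimensions to (n_batches, n_steps, features).
        # tf.pack was renamed tf.stack in TensorFlow 1.0.
        input = tf.reshape(input, tf.pack([n_batches, n_steps, -1]))
        # tf.scan iterates over the leading axis, so switch to time-major.
        shuffled_input = tf.transpose(input, (1, 0, 2))
        shuffled_input.set_shape((None, None, self.input_shape[-1]))
        hs = tf.scan(
            self.step,
            elems=shuffled_input,
            initializer=state
        )
        # Back to batch-major: (n_batches, n_steps, num_units).
        shuffled_hs = tf.transpose(hs, (1, 0, 2))
        return shuffled_hs
```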
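The method unrolls a recurrent cell in one of two ways: an explicit Python loop when `self.horizon` is fixed, or `tf.scan` over a time-major copy of the input otherwise. The pure-Python sketch below mirrors both paths without TensorFlow to show that they compute the same hidden states. It is an illustration under stated assumptions, not the project's code: the snippet does not show how `self.gru` or `self.step` are defined, so `make_gru_step` uses one common GRU formulation (biases omitted, random weights), and the names `scan`, `unroll_loop`, `unroll_scan` and `make_gru_step` are hypothetical.

```python
import math
import random
from typing import Callable, List

Vector = List[float]
Matrix = List[List[float]]
Step = Callable[[Vector, Vector], Vector]


def matvec(m: Matrix, v: Vector) -> Vector:
    """Multiply a (rows x cols) matrix by a length-cols vector."""
    return [sum(w * x for w, x in zip(row, v)) for row in m]


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def make_gru_step(num_inputs: int, num_units: int, seed: int = 0) -> Step:
    """Hypothetical toy GRU cell: h_t = step(h_{t-1}, x_t), random weights, no biases."""
    rng = random.Random(seed)

    def rand(rows: int, cols: int) -> Matrix:
        return [[rng.uniform(-0.5, 0.5) for _ in range(cols)] for _ in range(rows)]

    W_z, U_z = rand(num_units, num_inputs), rand(num_units, num_units)
    W_r, U_r = rand(num_units, num_inputs), rand(num_units, num_units)
    W_h, U_h = rand(num_units, num_inputs), rand(num_units, num_units)

    def step(h_prev: Vector, x: Vector) -> Vector:
        z = [sigmoid(a + b) for a, b in zip(matvec(W_z, x), matvec(U_z, h_prev))]  # update gate
        r = [sigmoid(a + b) for a, b in zip(matvec(W_r, x), matvec(U_r, h_prev))]  # reset gate
        rh = [ri * hi for ri, hi in zip(r, h_prev)]
        h_tilde = [math.tanh(a + b) for a, b in zip(matvec(W_h, x), matvec(U_h, rh))]
        # Interpolate between the previous state and the candidate state.
        return [(1 - zi) * hi + zi * ti for zi, hi, ti in zip(z, h_prev, h_tilde)]

    return step


def scan(fn, elems, initializer):
    """Like tf.scan: fold fn(acc, elem) over the leading axis, keeping every accumulator."""
    acc, out = initializer, []
    for elem in elems:
        acc = fn(acc, elem)
        out.append(acc)
    return out


def unroll_loop(step: Step, h0: Vector, inputs, horizon: int):
    """Fixed-horizon branch: inputs is [batch][time][features]."""
    state = [list(h0) for _ in inputs]  # tile h0 across the batch
    outputs = [[] for _ in inputs]      # outputs[b][t] = hidden state
    for t in range(horizon):
        state = [step(h, seq[t]) for h, seq in zip(state, inputs)]
        for b, h in enumerate(state):
            outputs[b].append(h)
    return outputs


def unroll_scan(step: Step, h0: Vector, inputs):
    """Scan branch: transpose to time-major, scan over time, transpose back."""
    state = [list(h0) for _ in inputs]
    time_major = [list(frame) for frame in zip(*inputs)]  # [time][batch][features]

    def batched_step(hs, xs):
        return [step(h, x) for h, x in zip(hs, xs)]

    hs = scan(batched_step, time_major, state)  # [time][batch][units]
    return [list(seq) for seq in zip(*hs)]      # [batch][time][units]


step = make_gru_step(num_inputs=3, num_units=4)
h0 = [0.0] * 4
rng = random.Random(1)
# batch=2, time=5, features=3
inputs = [[[rng.uniform(-1, 1) for _ in range(3)] for _ in range(5)] for _ in range(2)]
loop_out = unroll_loop(step, h0, inputs, horizon=5)
scan_out = unroll_scan(step, h0, inputs)
```

Both functions call the same step on the same values in the same order, so their outputs match exactly; the scan path exists because it does not need the number of time steps at graph-construction time, which is why the original falls back to `tf.scan` when `self.horizon` is `None`.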