nn.py 文件源码

python
阅读 30 收藏 0 点赞 0 评论 0

项目:rltools 作者: sisl 项目源码 文件源码
def output(self):
        """Iterate through hidden states to get outputs for all"""
        input_shape = tf.shape(self._input_B_T_Di)
        input = tf.reshape(self._input_B_T_Di, tf.pack([input_shape[0], input_shape[1], -1]))
        h0s = tf.tile(tf.reshape(self.h0, (1, self._hidden_units)), (input_shape[0], 1))
        # Flatten extra dimension
        shuffled_input = tf.transpose(input, (1, 0, 2))
        hs = tf.scan(self.step, elems=shuffled_input, initializer=h0s)
        shuffled_hs = tf.transpose(hs, (1, 0, 2))
        return shuffled_hs
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号