categorical.py source file

python

Project: rltools  Author: sisl
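def _make_actiondist_ops(self, obs_B_H_Df):
    # Suffix convention (inferred from the names): B = batch, H = horizon
    # (time steps), Df = observation features, Da = action input size,
    # Pa = number of discrete actions.
    # Note: tf.pack and tf.concat(axis, values) are the pre-1.0 TensorFlow API
    # (later tf.stack and tf.concat(values, axis)).
    B = tf.shape(obs_B_H_Df)[0]
    H = tf.shape(obs_B_H_Df)[1]
    # Flatten all trailing observation dimensions into a single feature axis.
    flatobs_B_H_Df = tf.reshape(obs_B_H_Df, tf.pack([B, H, -1]))
    if self.state_include_action:
        # Feed the previous action to the network alongside the observation.
        net_in = tf.concat(2, [flatobs_B_H_Df, self._prev_actions_B_H_Da])
        net_shape = (np.prod(self.observation_space.shape) + self.action_space.n,)
    else:
        net_in = flatobs_B_H_Df
        net_shape = (np.prod(self.observation_space.shape),)
    with tf.variable_scope('net'):
        net = nn.GRUNet(net_in, net_shape, self.action_space.n, self.hidden_spec)

    # XXX: reads a private attribute of the network
    self.hidden_dim = net._hidden_dim

    # Normalize raw scores into log-probabilities (log-softmax over the action axis).
    scores_B_H_Pa = net.output
    actiondist_B_H_Pa = scores_B_H_Pa - tfutil.logsumexp(scores_B_H_Pa, axis=2)

    # Single-step function: (input, previous hidden state) -> (output, next hidden state).
    compute_step_prob = tfutil.function([net.step_input, net.step_prev_hidden],
                                        [net.step_output, net.step_hidden])
    return actiondist_B_H_Pa, net.step_input, compute_step_prob, net.hid_init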
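
The shape handling and the normalization step in the function above can be tried in isolation. The following is a minimal NumPy sketch, not the rltools implementation: the sizes, the one-hot encoding of previous actions, the random stand-in for the network output, and the helper name `log_softmax` are all assumptions made for illustration. It flattens the trailing observation dimensions, appends the previous actions on the last axis, and turns scores into log-probabilities by subtracting a log-sum-exp over the action axis.

```python
import numpy as np

def log_softmax(scores, axis=-1):
    """Return log-probabilities: scores - logsumexp(scores) along `axis`."""
    # Subtract the per-row maximum first so exp() cannot overflow.
    shifted = scores - np.max(scores, axis=axis, keepdims=True)
    lse = np.log(np.sum(np.exp(shifted), axis=axis, keepdims=True))
    return shifted - lse

# Hypothetical sizes: batch B, horizon H, observation shape (2, 2), Pa actions.
B, H, Pa = 2, 3, 4
rng = np.random.default_rng(0)
obs_B_H_Df = rng.normal(size=(B, H, 2, 2))

# Assumption: previous actions are one-hot vectors of length Pa.
prev_actions_B_H_Da = np.zeros((B, H, Pa))
prev_actions_B_H_Da[..., 0] = 1.0

# Flatten trailing observation dims, then append previous actions on axis 2.
flatobs_B_H_Df = obs_B_H_Df.reshape(B, H, -1)
net_in = np.concatenate([flatobs_B_H_Df, prev_actions_B_H_Da], axis=2)

# Stand-in for the network output; large magnitudes exercise numerical stability.
scores_B_H_Pa = rng.normal(size=(B, H, Pa)) * 50.0
logprobs_B_H_Pa = log_softmax(scores_B_H_Pa, axis=2)
```

Exponentiating `logprobs_B_H_Pa` and summing over the last axis should give 1 for every (batch, time step) pair, which is a quick sanity check that the result is a valid categorical distribution in log space.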