state_q_functions.py 文件源码

python
阅读 20 收藏 0 点赞 0 评论 0

项目:chainerrl 作者: chainer 项目源码 文件源码
def __init__(self, n_input_channels, n_dim_action, n_hidden_channels,
                 n_hidden_layers, action_space, scale_mu=True):
        """Create the hidden stack and the output heads of this Q-function.

        Args:
            n_input_channels: Input size of the first hidden layer.
            n_dim_action: Output size of the ``mu`` and ``mat_diag`` heads;
                also determines the size of ``mat_non_diag``. It is not
                stored on the instance.
            n_hidden_channels: Output size of every hidden layer and input
                size of every head.
            n_hidden_layers: Number of hidden linear layers; must be >= 1.
            action_space: Stored as-is on the instance; must not be None.
            scale_mu: Flag stored as-is on the instance.

        Raises:
            AssertionError: If ``action_space`` is None or
                ``n_hidden_layers`` is smaller than 1 (only while
                assertions are enabled).
        """
        # Plain configuration values are recorded before the parent
        # initializer runs, exactly as many attributes as are read back later.
        self.n_input_channels = n_input_channels
        self.n_hidden_layers = n_hidden_layers
        self.n_hidden_channels = n_hidden_channels

        assert action_space is not None
        self.scale_mu = scale_mu
        self.action_space = action_space

        super().__init__()
        with self.init_scope():
            assert n_hidden_layers >= 1

            # The first layer consumes the raw input; each of the remaining
            # (n_hidden_layers - 1) layers maps hidden -> hidden.
            entry_layer = L.Linear(n_input_channels, n_hidden_channels)
            inner_layers = [
                L.Linear(n_hidden_channels, n_hidden_channels)
                for _ in range(n_hidden_layers - 1)
            ]
            self.hidden_layers = chainer.ChainList(entry_layer, *inner_layers)

            # Output heads, all fed by the last hidden layer's width.
            self.v = L.Linear(n_hidden_channels, 1)
            self.mu = L.Linear(n_hidden_channels, n_dim_action)
            self.mat_diag = L.Linear(n_hidden_channels, n_dim_action)

            # Number of entries strictly below the diagonal of an
            # (n_dim_action x n_dim_action) matrix. When it is zero
            # (e.g. n_dim_action == 1) the ``mat_non_diag`` attribute is
            # intentionally left undefined.
            n_off_diagonal = n_dim_action * (n_dim_action - 1) // 2
            if n_off_diagonal > 0:
                self.mat_non_diag = L.Linear(n_hidden_channels, n_off_diagonal)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号