train_agent_chainer.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:gym-malware 作者: endgameinc 项目源码 文件源码
def __init__(self, obs_size, n_actions, n_hidden_channels=[1024,256]):
        super(QFunction,self).__init__()
        net = []
        inpdim = obs_size
        for i,n_hid in enumerate(n_hidden_channels):
            net += [ ('l{}'.format(i), L.Linear( inpdim, n_hid ) ) ]
            net += [ ('norm{}'.format(i), L.BatchNormalization( n_hid ) ) ]
            net += [ ('_act{}'.format(i), F.relu ) ]
            inpdim = n_hid

        net += [('output', L.Linear( inpdim, n_actions) )]

        with self.init_scope():
            for n in net:
                if not n[0].startswith('_'):
                    setattr(self, n[0], n[1])

        self.forward = net
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号