model.py source code

python

Project: reinforcement-learning  Author: musyoku
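def build_head(self, k=0, units=None):
    # `units` lists the layer sizes in order: [n_in, n_hidden, ..., n_out]
    if units is None:
        raise ValueError("units must be specified")
    config.check()
    wscale = config.q_wscale

    # Fully connected part of the Q-network: each layer is a LinearHead
    # (given config.q_k_heads) followed by a BatchNormalization
    fc_attributes = {}
    # list() is required in Python 3: zip() returns a one-shot iterator,
    # which enumerate() would consume and which has no len()
    fc_units = list(zip(units[:-1], units[1:]))

    for i, (n_in, n_out) in enumerate(fc_units):
        fc_attributes["layer_%i" % i] = LinearHead(n_in, n_out, config.q_k_heads, wscale=wscale)
        fc_attributes["batchnorm_%i" % i] = BatchNormalization(n_out)

    fc = FullyConnectedNetwork(**fc_attributes)
    # Every layer except the output layer counts as hidden
    fc.n_hidden_layers = len(fc_units) - 1
    fc.activation_function = config.q_fc_activation_function
    fc.apply_dropout = config.q_fc_apply_dropout
    if config.use_gpu:
        fc.to_gpu()
    return fc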
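The layer wiring in `build_head` can be traced without Chainer. Below is a minimal stdlib sketch that assumes only what the method itself shows: consecutive entries of `units` are paired into `(n_in, n_out)` shapes, each layer gets a `layer_%i` and a `batchnorm_%i` entry, and `n_hidden_layers` is one less than the number of layers. The helpers `layer_shapes`, `n_hidden_layers` and `layer_plan` are hypothetical. They record plain shapes in place of the repository's `LinearHead`, `BatchNormalization` and `FullyConnectedNetwork` objects, which are not reproduced here.

```python
def layer_shapes(units):
    """Hypothetical helper: pair consecutive layer sizes as build_head does."""
    if units is None or len(units) < 2:
        raise ValueError("units needs at least an input and an output size")
    # list() matters in Python 3: zip() is a one-shot iterator without len()
    return list(zip(units[:-1], units[1:]))


def n_hidden_layers(units):
    """Hypothetical helper: build_head counts every layer but the last as hidden."""
    return len(layer_shapes(units)) - 1


def layer_plan(units):
    """Hypothetical helper: the attribute names build_head would create,
    mapped to shapes instead of real Chainer links."""
    plan = {}
    for i, (n_in, n_out) in enumerate(layer_shapes(units)):
        plan["layer_%i" % i] = (n_in, n_out)   # stands in for LinearHead(n_in, n_out, ...)
        plan["batchnorm_%i" % i] = n_out       # stands in for BatchNormalization(n_out)
    return plan


if __name__ == "__main__":
    print(layer_shapes([4, 64, 64, 2]))
    print(n_hidden_layers([4, 64, 64, 2]))
    print(layer_plan([4, 64, 2]))
```

For `units = [4, 64, 2]`, `layer_plan` yields `layer_0` with shape (4, 64), `batchnorm_0` over 64 units, `layer_1` with shape (64, 2) and `batchnorm_1` over 2 units, so the network has a single hidden layer.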