model.py source code

python

Project: self-driving-cars    Author: musyoku
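import chainer.links as L
# `config` and `FullyConnectedNetwork` come from the self-driving-cars project itself.
# Written against the Chainer v1 API, where L.Linear still accepts `wscale`.

def build_network(self, output_dim=1):
    # Validate the configuration before building any layers
    config.check()
    wscale = config.q_wscale

    # Fully connected part of the Q-network.
    # Input size: 34 values per step times the number of history steps kept.
    fc_attributes = {}
    fc_units = [(34 * config.rl_history_length, config.q_fc_hidden_units[0])]
    # Hidden -> hidden layers: pair each hidden size with the next one
    fc_units += list(zip(config.q_fc_hidden_units[:-1], config.q_fc_hidden_units[1:]))
    # Last hidden layer -> output (one value per output dimension)
    fc_units += [(config.q_fc_hidden_units[-1], output_dim)]

    # One Linear layer and one BatchNormalization layer per (n_in, n_out) pair
    for i, (n_in, n_out) in enumerate(fc_units):
        fc_attributes["layer_%i" % i] = L.Linear(n_in, n_out, wscale=wscale)
        fc_attributes["batchnorm_%i" % i] = L.BatchNormalization(n_out)

    # Register the layers as children of the network, then copy the
    # remaining options from the config onto it
    fc = FullyConnectedNetwork(**fc_attributes)
    fc.n_hidden_layers = len(fc_units) - 1
    fc.activation_function = config.q_fc_activation_function
    fc.apply_batchnorm = config.apply_batchnorm
    fc.apply_dropout = config.q_fc_apply_dropout
    fc.apply_batchnorm_to_input = config.q_fc_apply_batchnorm_to_input
    if config.use_gpu:
        fc.to_gpu()
    return fc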
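As a library-free sketch of the layer-size bookkeeping in `build_network`, the code below rebuilds the `fc_units` list the same way (input layer, zipped hidden-to-hidden pairs, output layer) and counts the resulting parameters. `fc_layer_shapes` and `count_parameters` are hypothetical helpers, not part of the project. The history length and hidden sizes in the usage are made-up values. The parameter count assumes every Linear layer has a bias and every BatchNormalization layer has learnable gamma and beta, which are Chainer's defaults.

```python
from typing import List, Tuple


def fc_layer_shapes(input_dim: int, hidden_units: List[int],
                    output_dim: int = 1) -> List[Tuple[int, int]]:
    """Return (n_in, n_out) for each Linear layer, mirroring build_network."""
    # First layer: raw input -> first hidden layer
    shapes = [(input_dim, hidden_units[0])]
    # Hidden -> hidden: pair each hidden size with the next one
    shapes += list(zip(hidden_units[:-1], hidden_units[1:]))
    # Last hidden layer -> output
    shapes += [(hidden_units[-1], output_dim)]
    return shapes


def count_parameters(shapes: List[Tuple[int, int]]) -> int:
    """Linear weights + biases, plus gamma and beta of each BatchNormalization."""
    linear = sum(n_in * n_out + n_out for n_in, n_out in shapes)
    batchnorm = sum(2 * n_out for _, n_out in shapes)
    return linear + batchnorm


if __name__ == "__main__":
    # Hypothetical settings: 4 history steps, two hidden layers of 128 units
    history_length = 4
    shapes = fc_layer_shapes(34 * history_length, [128, 128], output_dim=1)
    print(shapes)                   # [(136, 128), (128, 128), (128, 1)]
    print(len(shapes) - 1)          # n_hidden_layers -> 2
    print(count_parameters(shapes)) # 34691
```

With a single hidden size, the `zip` contributes nothing and the network reduces to input -> hidden -> output, so `n_hidden_layers` becomes 1.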