def build_network(self, output_dim=1):
    """Build the fully connected Q-network described by the global config.

    Args:
        output_dim: Number of units in the final (output) layer. Defaults to 1.

    Returns:
        A ``FullyConnectedNetwork`` whose linear/batchnorm links and behavior
        flags are taken from ``config``; moved to the GPU when
        ``config.use_gpu`` is set.
    """
    config.check()
    wscale = config.q_wscale
    # Layer widths from input to output:
    # input -> hidden_0 -> ... -> hidden_last -> output.
    # NOTE(review): the input width 34 * rl_history_length presumably matches
    # the environment's per-step feature size — confirm against the caller.
    hidden = config.q_fc_hidden_units
    layer_sizes = [34 * config.rl_history_length] + list(hidden) + [output_dim]

    # One Linear + BatchNormalization link per consecutive size pair.
    attributes = {}
    for idx in range(len(layer_sizes) - 1):
        n_in = layer_sizes[idx]
        n_out = layer_sizes[idx + 1]
        attributes["layer_%i" % idx] = L.Linear(n_in, n_out, wscale=wscale)
        attributes["batchnorm_%i" % idx] = L.BatchNormalization(n_out)

    net = FullyConnectedNetwork(**attributes)
    # Number of hidden layers = total weight layers minus the output layer.
    net.n_hidden_layers = len(layer_sizes) - 2
    net.activation_function = config.q_fc_activation_function
    net.apply_batchnorm = config.apply_batchnorm
    net.apply_dropout = config.q_fc_apply_dropout
    net.apply_batchnorm_to_input = config.q_fc_apply_batchnorm_to_input
    if config.use_gpu:
        net.to_gpu()
    return net
# (non-code scrape residue — blog-page navigation text, kept for the record:
#  "评论列表" = comment list, "文章目录" = article table of contents)