def build_head(self, k=0, units=None):
    # NOTE: k is accepted but not referenced within this method.
    if units is None:
        raise ValueError("units must specify the layer sizes, e.g. [n_in, n_hidden, ..., n_out]")
    config.check()
    wscale = config.q_wscale

    # Fully connected part of Q-Network
    fc_attributes = {}
    # Pair consecutive sizes into (n_in, n_out) for each linear layer.
    # zip() returns an iterator in Python 3, so materialize it as a list
    # before iterating here and taking len() below.
    fc_units = list(zip(units[:-1], units[1:]))
    for i, (n_in, n_out) in enumerate(fc_units):
        fc_attributes["layer_%i" % i] = LinearHead(n_in, n_out, config.q_k_heads, wscale=wscale)
        fc_attributes["batchnorm_%i" % i] = BatchNormalization(n_out)
    fc = FullyConnectedNetwork(**fc_attributes)
    fc.n_hidden_layers = len(fc_units) - 1
    fc.activation_function = config.q_fc_activation_function
    fc.apply_dropout = config.q_fc_apply_dropout
    if config.use_gpu:
        fc.to_gpu()
    return fc
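A minimal usage sketch, assuming a `config` object that exposes the fields read above (`q_wscale`, `q_k_heads`, `q_fc_activation_function`, `q_fc_apply_dropout`, `use_gpu`). The class name `QNetwork` and the concrete layer sizes are illustrative, not part of the original code; `units` runs from the input feature dimension through the hidden widths to the output dimension:

# Hypothetical caller; only build_head comes from the snippet above.
model = QNetwork()

# 256-dim input -> two 512-unit hidden layers -> 4 outputs.
# Each LinearHead internally holds config.q_k_heads parallel heads.
fc = model.build_head(k=0, units=[256, 512, 512, 4])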