def __call__(self, x):
    """Combine an advantage stream and a state-value stream into action values.

    The input is passed through every layer of ``self.conv_layers``, each
    followed by ``self.activation``. The resulting features feed both
    ``self.a_stream`` (advantage) and ``self.v_stream`` (state value). The
    advantage output is centred by subtracting its sum over axis 1 divided
    by ``self.n_actions``, and the state value is then added to it.

    Args:
        x: Input batch; ``x.shape[0]`` is used as the batch size.

    Returns:
        An ``action_value.DiscreteActionValue`` wrapping the combined values.
    """
    features = x
    for layer in self.conv_layers:
        features = self.activation(layer(features))

    n_samples = x.shape[0]

    # Advantage: subtract the per-sample mean over axis 1.
    # NOTE(review): assumes a_stream yields (batch, n_actions) -- confirm.
    advantage = self.a_stream(features)
    summed = F.sum(advantage, axis=1)
    centre = F.reshape(summed / self.n_actions, (n_samples, 1))
    advantage, centre = F.broadcast(advantage, centre)
    advantage -= centre

    # State value: broadcast against the centred advantage before adding.
    state_value = self.v_stream(features)
    advantage, state_value = F.broadcast(advantage, state_value)

    return action_value.DiscreteActionValue(advantage + state_value)
# Web page residue: "Comment list"
# Web page residue: "Table of contents"