# module-level imports assumed: numpy as np, chainer, chainer.functions as F
def __call__(self, x, t, train=True):
    # wrap inputs; volatile=True skips building the backprop graph at test time
    x = chainer.Variable(self.xp.asarray(x), volatile=not train)
    t = chainer.Variable(self.xp.asarray(t), volatile=not train)
    bs = x.data.shape[0]  # batch size
    self.clear(bs, train)
    # initial mean location, sampled uniformly from [-1, 1]^2 per example
    l = np.random.uniform(-1, 1, size=(bs, 2)).astype(np.float32)
    l = chainer.Variable(self.xp.asarray(l), volatile=not train)
    # run the core network for n_steps; only the last step emits the action
    sum_ln_pi = 0
    self.forward(x, train, action=False, init_l=l)
    for i in range(1, self.n_steps):
        action = (i == self.n_steps - 1)
        l, ln_pi, y, b = self.forward(x, train, action)
        if train:
            sum_ln_pi += ln_pi  # accumulate log-probabilities of sampled locations
    # classification loss: softmax cross entropy on the final prediction
    self.loss_action = F.softmax_cross_entropy(y, t)
    self.loss = self.loss_action
    self.accuracy = F.accuracy(y, t)
    if train:
        # reward: 1 if the prediction is correct, 0 otherwise
        conditions = self.xp.argmax(y.data, axis=1) == t.data
        r = self.xp.where(conditions, 1., 0.).astype(np.float32)
        # baseline loss: squared error between reward and predicted baseline
        self.loss_base = F.mean_squared_error(r, b)
        self.loss += self.loss_base
        # REINFORCE rule: weight the mean log-probability by the advantage (r - b)
        mean_ln_pi = sum_ln_pi / (self.n_steps - 1)
        self.loss_reinforce = F.sum(-mean_ln_pi * (r - b)) / bs
        self.loss += self.loss_reinforce
    return self.loss
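For context, a minimal sketch of how this __call__ could be driven from a training loop under Chainer v1 (the volatile flag above implies that API generation). The RAM constructor, the train_x/train_t arrays, and the iterate_minibatches helper are placeholders for illustration, not part of the original code:

from chainer import optimizers

model = RAM()                     # assumed constructor of the chain defining __call__ above
optimizer = optimizers.Adam()
optimizer.setup(model)

for x_batch, t_batch in iterate_minibatches(train_x, train_t, batch_size=32):  # hypothetical helper
    model.zerograds()             # reset accumulated gradients (Chainer v1 API)
    loss = model(x_batch, t_batch, train=True)
    loss.backward()
    optimizer.update()

At evaluation time the same call with train=False skips the reward, baseline, and REINFORCE terms and returns only the classification loss, with model.accuracy holding the batch accuracy.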