def test_contrastive_loss_value(self):
    """Check ContrastiveLoss output shape, dtype and value against a reference.

    The reference loop recomputes the loss from the raw fixture tensors
    ``self.x0`` / ``self.x1`` (paired samples) and ``self.t`` (pair labels):

    * label 1 (similar pair): add the squared Euclidean distance ``d``;
    * label 0 (dissimilar pair): add ``max(margin - sqrt(d), 0) ** 2``;
    * any other label contributes nothing;

    and the total is divided by ``2 * N``, where ``N`` is ``self.t.size()[0]``.
    """
    x0_val = Variable(self.x0)
    x1_val = Variable(self.x1)
    t_val = Variable(self.t)
    criterion = ContrastiveLoss(margin=self.margin)
    loss = criterion.forward(x0_val, x1_val, t_val)
    loss_array = loss.data.numpy()
    self.assertEqual(loss_array.shape, (1, ))
    self.assertEqual(loss_array.dtype, np.float32)
    # .item() extracts the single element; float() on a 1-element ndarray
    # is deprecated since NumPy 1.25.
    loss_value = loss_array.item()

    # Compute the expected value with plain Python floats so the comparison
    # does not depend on whether torch yields numbers or 0-dim tensors.
    loss_expect = 0.0
    for x0d, x1d, td in zip(self.x0, self.x1, self.t):
        d = float(torch.sum(torch.pow(x0d - x1d, 2)))
        label = float(td)
        if label == 1:  # similar pair: penalize the squared distance
            loss_expect += d
        elif label == 0:  # dissimilar pair: penalize being inside the margin
            # Must use the same margin the loss was constructed with; a
            # hard-coded 1 only matches when self.margin == 1.
            loss_expect += max(self.margin - np.sqrt(d), 0) ** 2
    loss_expect /= 2.0 * self.t.size()[0]
    self.assertAlmostEqual(
        loss_expect, loss_value, places=5,
        msg="expected %s got %s" % (loss_expect, loss_value))
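# 评论列表 ("comment list") — appears to be web-page residue, not part of the test
# 文章目录 ("table of contents") — appears to be web-page residue, not part of the test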