def test_change_get_loss(self, net_cls, module_cls, data):
    """Overriding ``get_loss`` can log extra per-batch values to history.

    A ``net_cls`` subclass computes the loss as the sum of an absolute-error
    term (``loss_a``) and a squared-error term (``loss_b``). During training
    it records each term per batch via ``history.record_batch``. After one
    epoch, every batch's recorded ``train_loss`` must equal
    ``loss_a + loss_b``.

    NOTE(review): the assertion relies on the net storing the value returned
    by ``get_loss`` under the ``'train_loss'`` batch key — confirm against
    the net implementation.
    """
    from skorch.utils import to_var

    class MyNet(net_cls):
        # pylint: disable=unused-argument
        def get_loss(self, y_pred, y_true, X=None, training=False):
            # NOTE(review): assumes y_pred has at least two columns and that
            # column 1 is the value compared against y_true — confirm
            # against module_cls.
            y_true = to_var(y_true, use_cuda=False)
            loss_a = torch.abs(y_true.float() - y_pred[:, 1]).mean()
            loss_b = ((y_true.float() - y_pred[:, 1]) ** 2).mean()
            # Record the components only for training batches; get_loss is
            # presumably also invoked with training=False (e.g. validation).
            if training:
                # [0] unwraps the reduced loss. This assumes to_numpy yields
                # a 1-element array here (older PyTorch); a 0-d result would
                # need .item() instead.
                self.history.record_batch('loss_a', to_numpy(loss_a)[0])
                self.history.record_batch('loss_b', to_numpy(loss_b)[0])
            return loss_a + loss_b

    X, y = data
    net = MyNet(module_cls, max_epochs=1)
    net.fit(X, y)

    # Per-batch (train_loss, loss_a, loss_b) triples from the last (and,
    # with max_epochs=1, only) epoch.
    all_losses = net.history[
        -1, 'batches', :, ('train_loss', 'loss_a', 'loss_b')]
    diffs = [total - a - b for total, a, b in all_losses]
    assert np.allclose(diffs, 0, atol=1e-7)
评论列表
文章目录