def test_model_custom_loss():
    """Train a small MLP with a user-defined loss callable and check the history.

    Verifies that ``Model.fit`` accepts a plain Python function as ``loss``,
    records one loss value per epoch as a built-in ``float`` under
    ``history['loss']``, and that training reduces the loss overall.
    """
    epochs = 10

    # Fork the CPU RNG so the random data (and anything else drawn from the
    # global generator during construction/fitting) is reproducible, without
    # leaking a fixed seed into other tests once the block exits.
    with torch.random.fork_rng(devices=[]):
        torch.manual_seed(0)
        x = torch.rand(20, 4)
        y = torch.rand(20, 10)
        model = Model(
            Dense(10, input_dim=x.shape[-1]),
            Activation('relu'),
            Dense(5),
            Activation('relu'),
            Dense(y.shape[-1]),
        )
        opt = SGD(lr=0.01, momentum=0.9)

        def mae(y_true, y_pred):
            # Mean absolute error; abs() is symmetric, so the argument order
            # the framework uses when calling the loss does not matter.
            return torch.mean(torch.abs(y_true - y_pred))

        history = model.fit(x, y, loss=mae, optimizer=opt, epochs=epochs)

    losses = history['loss']
    assert len(losses) == epochs
    # Deliberately stricter than isinstance(): rejects float subclasses such
    # as numpy.float64 and non-float types such as 0-d tensors.
    assert all(type(v) is float for v in losses)
    # SGD with momentum does not guarantee a monotonically decreasing
    # per-epoch loss, so only require net progress over the whole run.
    assert losses[-1] < losses[0]
评论列表
文章目录