def on_epoch_done(epoch, n, o, loss, acc, valid_loss, valid_acc, test_loss, test_acc):
    """Per-epoch callback: log metrics, checkpoint models, and adjust the LR.

    Args:
        epoch: 0-based index of the epoch that just finished.
        n: the network (chainer.Chain) passed to ``serializers.save_npz``.
        o: the optimizer; its state is serialized and its learning rate
           (``lr`` or ``alpha``) is adjusted in place.
        loss, acc: training loss and accuracy for this epoch.
        valid_loss, valid_acc: validation loss and accuracy.
        test_loss, test_acc: test loss and accuracy.

    Side effects: prints metrics, writes ``.model``/``.state`` snapshots,
    mutates the module-level ``state`` dict, and appends a CSV row to
    ``log_file_path``.  Relies on module globals ``state``, ``serializers``,
    ``model_prefix``, ``args``, and ``log_file_path``.
    """
    # Accuracies are fractions in [0, 1]; report errors as percentages.
    error = 100 * (1 - acc)
    valid_error = 100 * (1 - valid_acc)
    test_error = 100 * (1 - test_acc)
    print('epoch {} done'.format(epoch))
    print('train loss: {} error: {}'.format(loss, error))
    print('valid loss: {} error: {}'.format(valid_loss, valid_error))
    print('test loss: {} error: {}'.format(test_loss, test_error))
    # Keep a snapshot of the best model so far, as judged on the validation set.
    if valid_error < state['best_valid_error']:
        serializers.save_npz('{}.model'.format(model_prefix), n)
        serializers.save_npz('{}.state'.format(model_prefix), o)
        state['best_valid_error'] = valid_error
        state['best_test_error'] = test_error
    # Periodic checkpoint every `save_iter` epochs (disabled when save_iter <= 0).
    if args.save_iter > 0 and (epoch + 1) % args.save_iter == 0:
        serializers.save_npz('{}_{}.model'.format(model_prefix, epoch + 1), n)
        serializers.save_npz('{}_{}.state'.format(model_prefix, epoch + 1), o)
    # Warm up the learning rate over the first 10 epochs to prevent
    # divergence when training the identity-mapping model.
    if args.model == 'identity_mapping' and epoch < 9:
        o.lr = 0.01 + 0.01 * (epoch + 1)
    # NOTE: "lr_decay_iter" was originally a list holding a decay schedule;
    # it was simplified to a single interval for Python 3.5 compatibility.
    if (epoch + 1) % args.lr_decay_iter == 0 and epoch > 1:
        # BUG FIX: the original probed hasattr(optimizer, 'alpha') — a
        # different (module-global) object — while mutating `o`.  Probe the
        # object we actually mutate.
        if hasattr(o, 'alpha'):
            o.alpha *= 0.1  # Adam-family optimizers expose `alpha`
        else:
            o.lr *= 0.1     # SGD-family optimizers expose `lr`
    # BUG FIX: time.clock() was removed in Python 3.8; perf_counter() is the
    # documented replacement for measuring elapsed wall time.  (Assumes
    # state['clock'] is initialized with the same timer elsewhere — verify.)
    clock = time.perf_counter()
    print('elapsed time: {}'.format(clock - state['clock']))
    state['clock'] = clock
    with open(log_file_path, 'a') as f:
        f.write('{},{},{},{},{},{},{}\n'.format(
            epoch + 1, loss, error, valid_loss, valid_error, test_loss, test_error))
# (removed scraped-page residue that was not Python: "评论列表" [comment list] / "文章目录" [article TOC])