def do_checkpoint(prefix, period=1):
    """Create a callback that checkpoints the model to ``prefix``.

    Parameters
    ----------
    prefix : str
        The file prefix to checkpoint to.
    period : int, optional
        Save a checkpoint only when ``iter_no + 1`` is a multiple of
        ``period``. Values below 1 are treated as 1. The default (1) saves
        on every call, which matches the original behavior.

    Returns
    -------
    callback : function
        A callback ``callback(iter_no, sym, arg, aux)`` to pass to ``fit`` as
        its end-of-epoch callback. The original docstring names this
        ``iter_end_callback``; confirm the parameter name against the MXNet
        version in use. When it fires, it calls
        ``save_checkpoint(prefix, iter_no + 1, sym, arg, aux)``.
    """
    # Clamp to at least 1 so a zero period cannot cause a modulo-by-zero
    # inside the callback.
    period = int(max(1, period))

    def _callback(iter_no, sym, arg, aux):
        """Save a checkpoint when ``iter_no + 1`` is a multiple of ``period``."""
        # NOTE(review): a disabled (commented-out) step used to sit here. When
        # config.TRAIN.BBOX_NORMALIZATION_PRECOMPUTED was set, it rescaled
        # arg['bbox_pred_weight'] / arg['bbox_pred_bias'] with
        # config.TRAIN.BBOX_STDS / BBOX_MEANS before saving. Because it is
        # disabled, parameters are saved exactly as received. If that option
        # is enabled, confirm the un-normalization happens elsewhere.
        #
        # The checkpoint index is iter_no + 1, presumably turning a 0-based
        # epoch counter into a 1-based checkpoint number.
        if (iter_no + 1) % period == 0:
            save_checkpoint(prefix, iter_no + 1, sym, arg, aux)

    return _callback
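# NOTE: two stray web-page navigation labels ("Comment list" and
# "Table of contents") appeared here; they are not part of the module.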