def train_loop(sess, train_targets, num_minibatches=1, **loop_params):
    """Run the default minibatch training loop.

    Performs ``num_minibatches`` session runs. The first
    ``num_minibatches - 1`` runs fetch only the ``'__grads__'`` target of
    every entry in ``train_targets``. The final run fetches ``train_targets``
    in full.

    Args:
        sess (tf.Session): Current tensorflow session.
        train_targets (list of dict): One mapping of target operations per
            entry, evaluated by ``sess.run``. Every mapping must contain the
            keys ``'__grads__'`` and ``'optimizer'``. By default,
            ``base.train_from_params`` inserts these targets to facilitate
            minibatching:

            * ``__grads__`` (tf.Operation): Accumulates and stores gradients.
            * ``optimizer`` (tf.Operation): Applies and zeros gradients.

        num_minibatches (int): Number of minibatches to use. The value is
            converted with ``int()``; values below 1 behave like 1 (a single
            run of the full targets).
        **loop_params (mapping): Additional, user-defined kwargs for custom
            training loops. Accepted for interface compatibility; unused by
            this default loop.

    Returns:
        The result of ``sess.run(train_targets)``: the train targets
        evaluated by the session.

    Raises:
        ValueError: If any entry of ``train_targets`` lacks ``'__grads__'``
            or ``'optimizer'``.
    """
    required_keys = ('__grads__', 'optimizer')
    # Explicit check rather than ``assert`` so validation survives ``python -O``.
    for index, targets in enumerate(train_targets):
        missing = [key for key in required_keys if key not in targets]
        if missing:
            raise ValueError(
                'train_targets[{}] is missing required key(s): {}'.format(
                    index, ', '.join(missing)))

    # The gradient-accumulation fetches are the same for every intermediate
    # minibatch, so build the fetch list once.
    grads_fetches = [targets['__grads__'] for targets in train_targets]
    for _ in range(int(num_minibatches) - 1):
        # Accumulate gradient for each minibatch.
        sess.run(grads_fetches)
    # Compute final targets (includes zeroing gradient accumulator variable).
    return sess.run(train_targets)
评论列表
文章目录