import torch
from torch.autograd import Variable


def shards(state, shard_size, eval=False):
"""
Args:
state: A dictionary which corresponds to the output of
*LossCompute.make_shard_state(). The values for
those keys are Tensor-like or None.
shard_size: The maximum size of the shards yielded by the model.
eval: If True, only yield the state, nothing else.
Otherwise, yield shards.
Yields:
Each yielded shard is a dict.
Side effect:
After the last shard, this function does back-propagation.
"""
    if eval:
        yield state
    else:
        # non_none: the subdict of the state dictionary where the values
        # are not None.
        non_none = dict(filter_shard_state(state))
        # Now, the iteration: non_none maps each key to a single
        # tensor-like value, but we want a sequence of dictionaries of
        # tensors, one per shard. First, unzip the dictionary into a
        # sequence of keys and a sequence of tensor-like sequences,
        # splitting each value into shard_size chunks along dimension 0.
        keys, values = zip(*((k, torch.split(v, shard_size))
                             for k, v in non_none.items()))
        # Now, yield a dictionary for each shard. The keys are always
        # the same. values is a sequence of length #keys where each
        # element is a sequence of length #shards. We want to iterate
        # over the shards, not over the keys: therefore, the values need
        # to be re-zipped by shard and then each shard can be paired
        # with the keys.
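        # For illustration, with hypothetical keys:
        #   keys   = ("output", "target")
        #   values = ((out_1, out_2), (tgt_1, tgt_2))
        # zip(*values) yields (out_1, tgt_1) then (out_2, tgt_2), and
        # each pair is zipped back with keys into a per-shard dict.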
        for shard_tensors in zip(*values):
            yield dict(zip(keys, shard_tensors))
        # By this point the caller is assumed to have called backward()
        # on every yielded shard, so the detached copies in non_none have
        # accumulated gradients in .grad. Pair each original variable
        # with its copy's gradient and run one final backward pass to
        # propagate those gradients through the rest of the graph.
        variables = ((state[k], v.grad.data) for k, v in non_none.items()
                     if isinstance(v, Variable) and v.grad is not None)
        inputs, grads = zip(*variables)
        torch.autograd.backward(inputs, grads)
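
The helper filter_shard_state() used above is not shown here. A minimal
sketch consistent with how shards() consumes it (the body below is an
assumption reconstructed from the final backward pass, using the same
pre-0.4 Variable API as the rest of the function):

def filter_shard_state(state):
    # Drop None values. Gradient-requiring Variables are re-wrapped
    # around their raw data, which detaches the copy from the main
    # graph, so each shard's backward pass only accumulates gradients
    # into the copy's .grad; shards() later bridges the two graphs via
    # torch.autograd.backward(inputs, grads).
    for k, v in state.items():
        if v is not None:
            if isinstance(v, Variable) and v.requires_grad:
                v = Variable(v.data, requires_grad=True)
            yield k, v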
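
To see the generator's side effect in context, here is a hypothetical
training step (compute_loss, output, target, batch_size, and the shard
size of 32 are illustrative, not part of the original code):

state = {"output": output, "target": target}  # e.g. make_shard_state() output
for shard in shards(state, shard_size=32):
    # Each backward() here only traverses the small per-shard subgraph.
    loss = compute_loss(**shard)
    loss.div(batch_size).backward()
# When the generator is exhausted, shards() itself runs the final
# torch.autograd.backward, pushing the accumulated shard gradients
# through the full model graph exactly once.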