import torch


def shards(data, size=25, test=False):
"""
Generator over variables that will be involved in a costly loss computation
such as the softmax. It yields dictionaries of the same form as the input,
where the variables have been splitted in smaller shards and detach from
the graph. It expects the consumer to back propagate through them in shards
of given a size. After all shards are consumed, the generator will take
care of backprop further from the input using the accumulated gradients.
"""
    # Inspired by www.github.com/OpenNMT/OpenNMT-py/blob/master/onmt/Loss.py
    if test:
        yield data
        return
    # Detach the variables from the graph so that the per-shard backward
    # passes stop here instead of traversing the full computation graph.
    detached = dict(detach_vars(data))
    # Split each variable into shards of at most `size` rows.
    splits = ((key, torch.split(v, size)) for key, v in detached.items())
    keys, splits = zip(*splits)
    for split in zip(*splits):
        yield dict(zip(keys, split))  # go and accumulate some loss
    # Gather the gradients accumulated on the detached variables and
    # backpropagate them through the original (non-detached) inputs.
    inputs, grads = [], []
    for key, var in detached.items():
        if var.grad is not None:
            inputs.append(data[key])
            grads.append(var.grad.data)
    torch.autograd.backward(inputs, grads, retain_graph=True)
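
The helper detach_vars is not defined in this snippet. Here is a minimal
sketch of what it presumably does, assuming it yields (key, variable)
pairs in which anything that requires gradients is cut from the graph
but keeps recording them, so that .grad gets populated by the per-shard
backward passes:

# Hypothetical sketch of the helper assumed by shards(); not the
# original implementation.
def detach_vars(data):
    for key, v in data.items():
        if v.requires_grad:
            # cut from the graph, but keep accumulating gradients
            v = v.detach().requires_grad_()
        yield key, v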
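
A usage sketch under assumed names (hidden, targets, and generator are
illustrative, not from the original code): the consumer computes the
loss shard by shard and backpropagates each one; exhausting the
generator then triggers the final backward, which carries the
accumulated gradients into the full graph.

import torch
import torch.nn as nn
import torch.nn.functional as F

# Hypothetical setup: a projection into a large vocabulary, the typical
# case where sharding the softmax computation pays off.
hidden = torch.randn(100, 50, requires_grad=True)  # (batch, dim)
targets = torch.randint(0, 10000, (100,))          # (batch,)
generator = nn.Linear(50, 10000)                   # dim -> vocab size

for shard in shards({'hidden': hidden, 'targets': targets}, size=25):
    # backprop stops at the detached shard; gradients accumulate there
    loss = F.cross_entropy(generator(shard['hidden']), shard['targets'])
    loss.backward()
# The loop exhausted the generator, which ran the final backward and
# propagated the accumulated shard gradients back into hidden.
assert hidden.grad is not None

The point of the detour through detached variables is memory: each
shard's softmax graph can be freed right after its backward pass,
instead of keeping the graph for the whole batch alive at once.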
# Initializers