utils.py source code

python

Project: seqmod  Author: emanjavacas
import torch


# detach_vars(data) is assumed to be a helper defined elsewhere in seqmod's
# utils.py (not shown in this excerpt) that yields (key, detached_variable) pairs.
def shards(data, size=25, test=False):
    """
    Generator over variables involved in a costly loss computation such as
    the softmax. It yields dictionaries of the same form as the input, in
    which the variables have been split into smaller shards and detached from
    the graph. The consumer is expected to backpropagate through each shard
    of the given size. Once all shards have been consumed, the generator
    backpropagates the accumulated gradients further from the input.
    """
    # Inspired by www.github.com/OpenNMT/OpenNMT-py/blob/master/onmt/Loss.py
    if test:
        yield data
        return

    detached = dict(detach_vars(data))
    splits = ((key, torch.split(v, size)) for key, v in detached.items())
    keys, splits = zip(*splits)

    for split in zip(*splits):
        yield dict(zip(keys, split))  # go and accumulate some loss

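    # Push the gradients accumulated on the detached copies back through
    # the original (still attached) variables.
    inputs, grads = [], []
    for key, var in detached.items():
        if var.grad is not None:
            inputs.append(data[key])
            grads.append(var.grad.data)

    if inputs:  # nothing to backpropagate if no shard produced a gradient
        torch.autograd.backward(inputs, grads, retain_graph=True)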

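The docstring's contract is easiest to see in a toy training step. The following is a minimal sketch, assuming a recent PyTorch: `detach_vars` is not shown in the excerpt, so the version below is a hypothetical stand-in, and the two-layer model, `full_batch_grads`, and `sharded_grads` are illustrative names only. The block checks that per-shard backward passes, followed by the generator's own final backward, yield the same parameter gradients as a single full-batch pass.

```python
import torch
import torch.nn.functional as F


def detach_vars(data):
    # Hypothetical stand-in for the helper the excerpt calls but does not show:
    # replace every tensor that requires grad with a detached leaf copy.
    for key, value in data.items():
        if value.requires_grad:
            value = value.detach().requires_grad_()
        yield key, value


def shards(data, size=25, test=False):
    # Condensed copy of the generator above so this block runs on its own.
    if test:
        yield data
        return
    detached = dict(detach_vars(data))
    keys, splits = zip(*((k, torch.split(v, size)) for k, v in detached.items()))
    for split in zip(*splits):
        yield dict(zip(keys, split))
    inputs = [data[k] for k, v in detached.items() if v.grad is not None]
    grads = [v.grad for v in detached.values() if v.grad is not None]
    if inputs:
        torch.autograd.backward(inputs, grads, retain_graph=True)


torch.manual_seed(0)
encoder = torch.nn.Linear(5, 4)  # cheap part of a toy model (hypothetical)
proj = torch.nn.Linear(4, 3)     # stands in for the costly softmax layer
x = torch.randn(10, 5)
targets = torch.randint(0, 3, (10,))
params = list(encoder.parameters()) + list(proj.parameters())


def full_batch_grads():
    # Reference: one forward/backward over the whole batch.
    for p in params:
        p.grad = None
    loss = F.cross_entropy(proj(encoder(x)), targets, reduction="sum")
    loss.backward()
    return [p.grad.clone() for p in params]


def sharded_grads(size):
    for p in params:
        p.grad = None
    data = {"hidden": encoder(x), "targets": targets}
    for shard in shards(data, size=size):
        # Only this shard's logits are alive at any one time.
        logits = proj(shard["hidden"])
        F.cross_entropy(logits, shard["targets"], reduction="sum").backward()
    # The for loop exhausted the generator, so it has already pushed the
    # accumulated gradients of "hidden" back through the encoder.
    return [p.grad.clone() for p in params]


reference = full_batch_grads()
sharded = sharded_grads(size=3)
print(all(torch.allclose(a, b, atol=1e-5) for a, b in zip(reference, sharded)))
```

The per-shard losses use `reduction="sum"`, so the shard gradients add up to the full-batch gradient. With a per-shard mean, each shard would be averaged over its own size (and the last, shorter shard would be weighted differently), so a normalized loss should instead divide the summed loss by the total number of targets.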
