benchmark.py source code

Project: deepspeech.pytorch    Author: SeanNaren
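import time

import torch

# model, criterion, optimizer, batch_size and seconds are module-level globals
# defined elsewhere in benchmark.py. The criterion call below matches warp-ctc's
# CTCLoss(acts, labels, act_lens, label_lens) signature.


def iteration(input_data):
    # Synthetic CTC targets: assuming 100 spectrogram frames per second, align
    # every utterance to a label sequence half as long as its frame count.
    # Computing the per-utterance length first keeps len(targets) equal to
    # target_sizes.sum() even when seconds * 100 is odd.
    target_len = int((seconds * 100) / 2)
    targets = torch.ones(batch_size * target_len, dtype=torch.int)  # flat labels for the whole batch
    target_sizes = torch.full((batch_size,), target_len, dtype=torch.int)
    input_percentages = torch.ones(batch_size, dtype=torch.int)

    inputs = input_data  # plain tensors replace the Variable wrapper deprecated in PyTorch 0.4
    torch.cuda.synchronize()  # keep previously queued GPU work out of the measurement
    start = time.time()
    out = model(inputs)
    out = out.transpose(0, 1)  # TxNxH

    seq_length = out.size(0)
    sizes = input_percentages.mul_(int(seq_length)).int()  # every utterance uses all T output frames
    loss = criterion(out, targets, sizes, target_sizes)
    loss = loss / inputs.size(0)  # average the loss over the minibatch
    # compute gradients and update the weights
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    torch.cuda.synchronize()  # wait for backward/optimizer kernels before stopping the clock
    end = time.time()
    del loss
    del out
    return start, end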
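The synthetic target sizes follow from the audio duration. The `seconds * 100` factor suggests 100 spectrogram frames per second (a 10 ms hop, assumed here), and each utterance is given a label sequence half as long as its frame count; the flat label buffer then holds `batch_size` such sequences back to back. The helper below, `synthetic_ctc_sizes`, is a hypothetical name used only to show the arithmetic. Deriving the per-utterance length first and multiplying by the batch size keeps the buffer length equal to the sum of the per-utterance sizes.

```python
def synthetic_ctc_sizes(batch_size, seconds, frames_per_second=100):
    """Return (labels_per_utterance, total_labels) for synthetic CTC targets.

    frames_per_second=100 is an assumption (10 ms spectrogram hop).
    Each utterance is aligned to half of its frame count.
    """
    labels_per_utterance = int((seconds * frames_per_second) / 2)
    # Multiply the already-truncated length so the flat buffer always equals
    # batch_size * labels_per_utterance, i.e. the sum of all target sizes.
    total_labels = batch_size * labels_per_utterance
    return labels_per_utterance, total_labels


# Example: a batch of 32 utterances, 15 seconds each
print(synthetic_ctc_sizes(32, 15))  # (750, 24000)
```

With 15 s of audio that is 1500 frames per utterance, so 750 labels each and 24000 labels in the flat buffer for a batch of 32.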
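Because `iteration()` returns its `(start, end)` timestamps rather than a duration, the caller is free to aggregate them. A minimal sketch of such a driver is shown below; `run_benchmark`, its `warmup`/`runs` parameters and the `fake_step` stand-in are hypothetical names, not part of benchmark.py. Warm-up calls are discarded on the assumption that the first iterations include one-off costs (for example cuDNN algorithm selection and allocator growth).

```python
import statistics
import time


def run_benchmark(step, warmup=2, runs=5):
    """Call `step` warmup + runs times and summarise the timed runs.

    `step` must return (start, end) timestamps, as iteration() does.
    Warm-up calls are executed but their timings are thrown away.
    """
    for _ in range(warmup):
        step()
    durations = []
    for _ in range(runs):
        start, end = step()
        durations.append(end - start)
    mean = statistics.mean(durations)
    stdev = statistics.stdev(durations) if len(durations) > 1 else 0.0
    return mean, stdev, durations


if __name__ == "__main__":
    # Hypothetical stand-in for iteration(input_data): sleeps ~10 ms.
    def fake_step():
        start = time.time()
        time.sleep(0.01)
        return start, time.time()

    mean, stdev, _ = run_benchmark(fake_step, warmup=1, runs=3)
    print(f"mean {mean * 1000:.1f} ms, stdev {stdev * 1000:.1f} ms")
```

With the real function this would be called as `run_benchmark(lambda: iteration(input_data))`, where `input_data` is a batch already placed on the GPU.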