def forward(ctx, add_batch, batch1, batch2, alpha=1, beta=1, inplace=False):
    """Run the forward computation by delegating to ``torch.baddbmm``.

    Before the computation, the two scalars are stored on ``ctx`` as
    attributes and the two batch operands are registered with
    ``ctx.save_for_backward``.

    Args:
        ctx: Context object. Receives ``alpha`` and ``beta`` as attributes
            and ``batch1``/``batch2`` through ``save_for_backward``.
        add_batch: Operand passed to ``torch.baddbmm`` right after ``alpha``;
            also handed to ``_get_output`` to obtain the ``out`` target.
        batch1: First batch operand passed to ``torch.baddbmm``.
        batch2: Second batch operand passed to ``torch.baddbmm``.
        alpha: Scalar passed as the first positional argument of
            ``torch.baddbmm``. Defaults to 1.
        beta: Scalar passed as the third positional argument of
            ``torch.baddbmm``. Defaults to 1.
        inplace: Forwarded unchanged to ``_get_output``. Defaults to False.

    Returns:
        The value returned by ``torch.baddbmm`` when called with
        ``out`` set to the object produced by ``_get_output``.
    """
    # Record the scalars and the batch operands on the context.
    # Presumably these are read back by the matching backward pass -- confirm.
    ctx.alpha, ctx.beta = alpha, beta
    ctx.save_for_backward(batch1, batch2)

    # _get_output is defined elsewhere in this file; it supplies the `out`
    # target (presumably add_batch itself when inplace=True -- TODO confirm).
    result_buffer = _get_output(ctx, add_batch, inplace=inplace)

    # NOTE(review): this is the scalar-first positional overload of
    # torch.baddbmm. In that overload the first scalar appears to scale the
    # added tensor and the second the batch product, i.e. `alpha` here scales
    # add_batch and `beta` scales batch1 @ batch2 -- the opposite of the
    # naming used by torch's keyword form. Confirm against the torch version
    # in use before changing the call.
    combined = torch.baddbmm(
        alpha, add_batch, beta, batch1, batch2, out=result_buffer
    )
    return combined
评论列表
文章目录