def forward(ctx, input, lower, upper, train, inplace):
ctx.lower = lower
ctx.upper = upper
ctx.train = train
ctx.inplace = inplace
ctx._backend = type2backend[type(input)]
if ctx.inplace:
ctx.mark_dirty(input)
output = input
else:
output = input.new(input.size())
ctx.noise = input.new()
ctx._backend.RReLU_updateOutput(
ctx._backend.library_state,
input,
output,
ctx.noise,
ctx.lower,
ctx.upper,
ctx.train,
ctx.inplace,
torch.default_generator if not input.is_cuda else 0
)
ctx.save_for_backward(input)
return output
评论列表
文章目录