def backward_weight(fn, input, hx, output, weight, grad_weight):
    with torch.cuda.device_of(input):
        handle = cudnn.get_handle()

        if fn.mode == cudnn.CUDNN_LSTM:
            hx, cx = hx
        else:
            cx = None

        if fn.batch_first:
            # cuDNN expects time-major input (seq_len, batch, input_size),
            # so swap the batch and sequence dimensions
            input = input.transpose(0, 1)
            output = output.transpose(0, 1)

        input_size = _input_size(fn)
        hidden_size = _hidden_size(fn)

        if not fn.train:
            raise RuntimeError('backward_weight can only be called when training!')
        if fn.dropout != 0 and lib.version < 5103:
            raise RuntimeError('dropout supported only in cudnn v 5.1 and above')
        if tuple(input.size()) != input_size:
            raise RuntimeError('Expected input size {}, got {}'.format(
                input_size, tuple(input.size())))
        if tuple(hx.size()) != hidden_size:
            raise RuntimeError('Expected hidden size {}, got {}'.format(
                hidden_size, tuple(hx.size())))

        x = input.contiguous()
        y = output
        dw = fn.weight_buf.new().resize_as_(fn.weight_buf).zero_()

        check_error(cudnn.lib.cudnnRNNBackwardWeights(
            handle,
            fn.rnn_desc,
            fn.seq_length,
            fn.x_descs, ctypes.c_void_p(x.data_ptr()),
            fn.hx_desc, ctypes.c_void_p(hx.data_ptr()),
            fn.y_descs, ctypes.c_void_p(y.data_ptr()),
            ctypes.c_void_p(fn.workspace.data_ptr()), fn.workspace.size(0),
            fn.w_desc, ctypes.c_void_p(dw.data_ptr()),
            ctypes.c_void_p(fn.reserve.data_ptr()), fn.reserve.size(0)
        ))

        # copy the gradients out of the flat dw buffer into the
        # individual grad_weight tensors
        grad_params = get_parameters(fn, handle, dw)
        _copyParams(grad_params, grad_weight)
        return grad_weight
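
For context, backward_weight is internal plumbing between PyTorch and cuDNN; user code never calls it directly. Below is a minimal sketch, assuming a recent PyTorch with the standard nn.LSTM interface, of how the same per-parameter weight gradients are reached from the public API. On a cuDNN-enabled GPU, the weight part of this backward pass ultimately goes through cudnnRNNBackwardWeights, as in the routine above. The shapes and names here are purely illustrative.

import torch
import torch.nn as nn

# small LSTM; sizes chosen only for illustration
lstm = nn.LSTM(input_size=10, hidden_size=20, num_layers=2)
x = torch.randn(5, 3, 10)    # (seq_len, batch, input_size) -- time-major, i.e. batch_first=False
h0 = torch.randn(2, 3, 20)   # (num_layers, batch, hidden_size)
c0 = torch.randn(2, 3, 20)

# moving lstm, x, h0, c0 to CUDA (e.g. lstm.cuda(), x.cuda(), ...) is what
# routes this backward pass through the cuDNN kernels; on CPU a reference
# implementation is used instead
output, (hn, cn) = lstm(x, (h0, c0))
output.sum().backward()      # fills .grad for every weight_ih_l*/weight_hh_l*/bias_* tensor

for name, param in lstm.named_parameters():
    print(name, tuple(param.grad.size()))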