def flatten_parameters(self):
"""Resets parameter data pointer so that they can use faster code paths.
Right now, this works only if the module is on the GPU and cuDNN is enabled.
Otherwise, it's a no-op.
"""
any_param = next(self.parameters()).data
if not any_param.is_cuda or not torch.backends.cudnn.is_acceptable(any_param):
self._data_ptrs = []
return
with torch.cuda.device_of(any_param):
# This is quite ugly, but it allows us to reuse the cuDNN code without larger
# modifications. It's really a low-level API that doesn't belong here, but
# let's make an exception.
from torch.backends.cudnn import rnn
from torch.backends import cudnn
from torch.nn._functions.rnn import CudnnRNN
handle = cudnn.get_handle()
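# Build a temporary CudnnRNN function object just to drive the descriptor and
# weight-buffer setup below; any warnings emitted while constructing it are
# recorded and dropped.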
with warnings.catch_warnings(record=True):
fn = CudnnRNN(
self.mode,
self.input_size,
self.hidden_size,
num_layers=self.num_layers,
batch_first=self.batch_first,
dropout=self.dropout,
train=self.training,
bidirectional=self.bidirectional,
dropout_state=self.dropout_state,
)
# Initialize descriptors
fn.datatype = cudnn._typemap[any_param.type()]
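# A dummy 1 x input_size input is enough here; it is only used for the
# weight-size and layout queries below.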
fn.x_descs = cudnn.descriptor(any_param.new(1, self.input_size), 1)
fn.rnn_desc = rnn.init_rnn_descriptor(fn, handle)
# Allocate buffer to hold the weights
self._param_buf_size = rnn.get_num_weights(handle, fn.rnn_desc, fn.x_descs[0], fn.datatype)
fn.weight_buf = any_param.new(self._param_buf_size).zero_()
fn.w_desc = rnn.init_weight_descriptor(fn, fn.weight_buf)
# Slice off views into weight_buf
params = rnn.get_parameters(fn, handle, fn.weight_buf)
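# Gather the module's current weight and bias tensors, grouped by layer.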
all_weights = [[p.data for p in l] for l in self.all_weights]
# Copy weights and update their storage
rnn._copyParams(all_weights, params)
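# Re-point each original parameter at its view into weight_buf, so the
# module's parameters now alias the flat buffer that cuDNN expects.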
for orig_layer_param, new_layer_param in zip(all_weights, params):
for orig_param, new_param in zip(orig_layer_param, new_layer_param):
orig_param.set_(new_param.view_as(orig_param))
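# Cache the flattened data pointers so later code can check whether the
# parameters still live in this contiguous buffer.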
self._data_ptrs = list(p.data.data_ptr() for p in self.parameters())
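# Usage sketch for flatten_parameters() above (illustrative only, not part of
# this module): the method is a no-op on CPU, so it is only worth calling on a
# CUDA module with cuDNN available, e.g.
#
#     rnn = nn.LSTM(input_size=10, hidden_size=20, num_layers=2).cuda()
#     rnn.flatten_parameters()
#     output, (h_n, c_n) = rnn(input, (h_0, c_0))
#
# where `input`, `h_0` and `c_0` are CUDA tensors of the shapes nn.LSTM expects.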