def _set(self, dropout, seed):
    """Apply ``dropout`` and ``seed`` to the underlying cuDNN dropout descriptor.

    When no RNG state buffer has been allocated yet (``self.state is None``)
    and ``dropout`` is positive, the required buffer size is queried from
    cuDNN, a CUDA byte tensor of that size is allocated and stored on
    ``self.state``, and its pointer and size are handed to
    ``cudnnSetDropoutDescriptor``. In every other case a null pointer and a
    size of 0 are passed instead.

    Args:
        dropout: Dropout value; converted to a C ``float`` for the call.
        seed: RNG seed; converted to a C ``unsigned long long`` for the call.

    Both cuDNN status codes are passed through ``check_error`` (defined
    elsewhere in this file; presumably raises on a non-success status).
    On return ``self.dropout`` holds the value that was just applied.
    """
    if self.state is None and dropout > 0:
        # cudnnDropoutGetStatesSize writes its result through a `size_t *`.
        # The out-parameter must therefore be a c_size_t: a c_long is only
        # 4 bytes on LLP64 platforms (64-bit Windows), so the 8-byte write
        # would overrun it.
        dropout_states_size = ctypes.c_size_t()
        check_error(lib.cudnnDropoutGetStatesSize(
            self.handle,
            ctypes.byref(dropout_states_size)))
        # Keep the tensor on the instance so the device memory outlives
        # this call.
        self.state = torch.cuda.ByteTensor(dropout_states_size.value)
        state_ptr = self.state.data_ptr()
        state_size = self.state.size(0)
    else:
        # NOTE(review): the cuDNN docs describe a NULL states pointer as
        # "only set the dropout value, leave RNG states untouched", which
        # appears to be the intent when a state already exists -- confirm
        # against the cuDNN version in use.
        state_ptr = None
        state_size = 0

    # `self` is passed directly as the descriptor argument; presumably the
    # enclosing class exposes `_as_parameter_` for ctypes -- not visible here.
    check_error(lib.cudnnSetDropoutDescriptor(
        self,
        self.handle,
        ctypes.c_float(dropout),
        ctypes.c_void_p(state_ptr),
        ctypes.c_size_t(state_size),
        ctypes.c_ulonglong(seed),
    ))

    self.dropout = dropout
评论列表
文章目录