def init_dropout_descriptor(fn, handle):
    """Create and configure a cuDNN dropout descriptor for ``fn``.

    Asks cuDNN how large its dropout state buffer must be, allocates a CUDA
    byte tensor of that size, and initialises a new
    ``cudnn.DropoutDescriptor`` with it.

    Args:
        fn: Object supplying ``dropout`` and ``seed``, both forwarded to
            ``DropoutDescriptor.set``. Presumably ``dropout`` is the dropout
            probability and ``fn`` is an RNN function object — TODO confirm
            against callers.
        handle: cuDNN handle used for both the size query and the
            descriptor setup.

    Returns:
        The configured ``cudnn.DropoutDescriptor``.

    Raises:
        Whatever ``check_error`` raises when the size query reports a
        failing cuDNN status.
    """
    dropout_desc = cudnn.DropoutDescriptor()

    # cudnnDropoutGetStatesSize writes a C ``size_t``. ``ctypes.c_long`` is
    # only 32 bits on LLP64 platforms (Windows), so cuDNN would write past
    # the end of the buffer there; ``c_size_t`` always matches the C type.
    dropout_states_size = ctypes.c_size_t()
    check_error(cudnn.lib.cudnnDropoutGetStatesSize(
        handle,
        ctypes.byref(dropout_states_size)))

    # Device buffer handed to cuDNN for its dropout RNG state.
    # NOTE(review): this tensor must outlive every use of the descriptor;
    # presumably DropoutDescriptor.set keeps a reference to it — confirm.
    dropout_states = torch.cuda.ByteTensor(dropout_states_size.value)

    dropout_desc.set(
        handle,
        fn.dropout,
        dropout_states,
        fn.seed,
    )
    return dropout_desc
评论列表
文章目录