def get_num_weights(handle, rnn_desc, x_desc, datatype):
    """Return the number of parameter elements cuDNN reports for an RNN.

    Queries ``cudnnGetRNNParamsSize`` for the total parameter buffer size in
    bytes, then converts that byte count into an element count using the
    per-element byte size of ``datatype`` looked up in ``cudnn._sizeofmap``.

    Args:
        handle: cuDNN handle, passed straight through to cuDNN.
        rnn_desc: RNN descriptor, passed straight through to cuDNN.
        x_desc: input tensor descriptor, passed straight through to cuDNN.
        datatype: cuDNN data type; passed to cuDNN and also used as the key
            into ``cudnn._sizeofmap`` to get the element size in bytes.

    Returns:
        int: the parameter buffer size measured in elements of ``datatype``.

    Raises:
        RuntimeError: if the byte size reported by cuDNN is not a whole
            multiple of the element size.
        Whatever ``check_error`` raises for a failing cuDNN status
        (presumably a cuDNN-specific error; confirm against its definition).
    """
    # cuDNN declares this output as ``size_t *sizeInBytes``. ``c_long`` is
    # only 32 bits on LLP64 platforms (e.g. 64-bit Windows), where cuDNN
    # would write past the end of the buffer; ``c_size_t`` always matches.
    weight_size = ctypes.c_size_t()
    check_error(cudnn.lib.cudnnGetRNNParamsSize(
        handle,
        rnn_desc,
        x_desc,
        ctypes.byref(weight_size),
        datatype,
    ))
    num_bytes = weight_size.value
    elem_size = cudnn._sizeofmap[datatype]
    # Explicit check rather than ``assert`` so it is not stripped under -O.
    if num_bytes % elem_size != 0:
        raise RuntimeError(
            "cudnnGetRNNParamsSize returned {} bytes, which is not a "
            "multiple of the element size {}".format(num_bytes, elem_size)
        )
    return num_bytes // elem_size
评论列表
文章目录