def _num_linear_layers(fn):
    """Return the number of linear layers cuDNN uses for ``fn.mode``.

    Args:
        fn: object exposing a ``mode`` attribute that is compared against
            the ``cudnn.CUDNN_*`` RNN mode constants.

    Returns:
        8 for ``CUDNN_LSTM``, 6 for ``CUDNN_GRU``, and 2 for both
        ``CUDNN_RNN_RELU`` and ``CUDNN_RNN_TANH``.

    Raises:
        RuntimeError: if ``fn.mode`` equals none of those constants.
    """
    # NOTE(review): these counts look like (gates per cell) x 2, i.e. one
    # input-to-hidden and one hidden-to-hidden matrix per gate -- confirm
    # against the cuDNN linLayerID documentation.
    # Checked in order with ``==``; the first matching mode wins.
    layer_counts = (
        (cudnn.CUDNN_LSTM, 8),
        (cudnn.CUDNN_GRU, 6),
        (cudnn.CUDNN_RNN_RELU, 2),
        (cudnn.CUDNN_RNN_TANH, 2),
    )
    for rnn_mode, count in layer_counts:
        if fn.mode == rnn_mode:
            return count
    raise RuntimeError('Unknown mode: {}'.format(fn.mode))
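# 评论列表 ("Comment list") -- stray web-page navigation text, not code
# 文章目录 ("Table of contents") -- stray web-page navigation text, not code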