def get_cudnn_mode(mode):
    """Map an RNN mode name to the corresponding ``cudnn`` RNN-mode constant.

    Args:
        mode: One of ``'RNN_RELU'``, ``'RNN_TANH'``, ``'LSTM'`` or ``'GRU'``.

    Returns:
        The matching ``cudnn.CUDNN_*`` attribute. ``cudnn`` is a module-level
        name defined elsewhere in this file (presumably a cuDNN bindings
        module such as ``torch.backends.cudnn`` -- confirm against the
        file's imports).

    Raises:
        ValueError: If ``mode`` is not one of the supported names.
            ``ValueError`` subclasses ``Exception``, so callers that caught
            the previous bare ``Exception`` keep working.
    """
    # Compare before touching ``cudnn`` so an invalid mode fails fast
    # without requiring the cudnn module attributes to exist.
    if mode == 'RNN_RELU':
        return cudnn.CUDNN_RNN_RELU
    if mode == 'RNN_TANH':
        return cudnn.CUDNN_RNN_TANH
    if mode == 'LSTM':
        return cudnn.CUDNN_LSTM
    if mode == 'GRU':
        return cudnn.CUDNN_GRU
    raise ValueError("Unknown mode: {}".format(mode))
评论列表
文章目录