rnn.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:pytorch-dist 作者: apaszke 项目源码 文件源码
def init_dropout_descriptor(fn, handle):
    dropout_desc = cudnn.DropoutDescriptor()

    dropout_states_size = ctypes.c_long()
    check_error(cudnn.lib.cudnnDropoutGetStatesSize(
        handle,
        ctypes.byref(dropout_states_size)))

    dropout_states = torch.cuda.ByteTensor(dropout_states_size.value)
    dropout_desc.set(
        handle,
        fn.dropout,
        dropout_states,
        fn.seed
    )
    return dropout_desc
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号