__init__.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:pytorch 作者: pytorch 项目源码 文件源码
def _set(self, dropout, seed):
        if self.state is None and dropout > 0:
            dropout_states_size = ctypes.c_long()
            check_error(lib.cudnnDropoutGetStatesSize(
                self.handle,
                ctypes.byref(dropout_states_size)))
            self.state = torch.cuda.ByteTensor(dropout_states_size.value)
            state_ptr = self.state.data_ptr()
            state_size = self.state.size(0)
        else:
            state_ptr = None
            state_size = 0

        check_error(lib.cudnnSetDropoutDescriptor(
            self,
            self.handle,
            ctypes.c_float(dropout),
            ctypes.c_void_p(state_ptr),
            ctypes.c_size_t(state_size),
            ctypes.c_ulonglong(seed),
        ))

        self.dropout = dropout
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号