embedding.py 文件源码

python
阅读 33 收藏 0 点赞 0 评论 0

项目:seqmod 作者: emanjavacas 项目源码 文件源码
def word_dropout_mask(X, dropout_rate, reserved_codes=()):
    """
    Computes a binary mask across batch examples based on a
    bernoulli distribution with mean equal to dropout.
    """
    probs = torch.zeros_like(X).float() + dropout_rate
    # zero reserved_codes (avoid dropping reserved symbols)
    if len(reserved_codes) > 0:
        probs[sum((X == x) for x in reserved_codes)] = 0
    # return binary mask
    return torch.bernoulli(probs).byte()
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号