def word_dropout_mask(X, dropout_rate, reserved_codes=()):
    """
    Sample an element-wise binary word-dropout mask shaped like ``X``.

    Each position is independently set to 1 with probability
    ``dropout_rate`` (a Bernoulli draw), except positions whose value in
    ``X`` equals one of ``reserved_codes``. Those are always 0, so reserved
    symbols are never selected for dropping.

    Args:
        X: tensor of token codes; the mask has the same shape and device.
        dropout_rate: probability in [0, 1] that a non-reserved position
            is marked with 1.
        reserved_codes: iterable of code values that must never be marked
            (any iterable is accepted, including generators).

    Returns:
        uint8 tensor with the same shape as ``X``: 1 where the token is
        selected for dropout, 0 elsewhere.
        NOTE(review): uint8 masks are deprecated for indexing in recent
        PyTorch; callers that index with this mask may want ``.bool()``.
    """
    probs = torch.zeros_like(X, dtype=torch.float) + dropout_rate
    # Build an explicit *boolean* mask of reserved positions. Summing the
    # per-code comparisons (``sum(X == c ...)``) starts from the int 0 and
    # promotes to int64. Indexing with int64 performs gather-style indexing
    # along dim 0 rather than masking, which zeroes the wrong entries.
    reserved = torch.zeros_like(X, dtype=torch.bool)
    for code in reserved_codes:
        reserved |= X == code
    # Zero probability => reserved symbols are never dropped.
    probs[reserved] = 0
    return torch.bernoulli(probs).byte()
评论列表
文章目录