def exp_mask(logits, mask): return torch.add(logits.data, (torch.ones(mask.size()) - mask.type(dtype)) * VERY_NEGATIVE_NUMBER)