def _dropout(self, values, salt_prefix, recurrent_noise, keep_prob):
    """Apply either standard or variational-recurrent dropout to ``values``.

    The path is selected by ``self._variational_recurrent``:

    * Truthy: each element of ``values`` is paired with the corresponding
      element of ``recurrent_noise``. The pair is handed to
      ``self._variational_recurrent_dropout_value`` together with
      ``keep_prob``. ``salt_prefix`` is not used on this path.
    * Falsy: each element is passed through ``nn_ops.dropout`` with
      ``keep_prob``. The seed comes from ``self._gen_seed(salt_prefix, index)``,
      so each element gets its own seed derived from the salt prefix and its
      position. ``recurrent_noise`` is not used on this path.

    Args:
      values: Structure of values to drop out. This is presumably a (nested)
        structure of tensors walked by ``_enumerated_map_structure``; confirm
        against callers.
      salt_prefix: Salt forwarded to ``self._gen_seed`` on the standard path.
      recurrent_noise: Structure matching ``values`` that supplies the noise
        on the variational path.
      keep_prob: Keep probability forwarded to whichever dropout routine runs.

    Returns:
      Whatever ``_enumerated_map_structure`` builds from the per-element
      results. Presumably this has the same structure as ``values``; TODO
      confirm.
    """
    if self._variational_recurrent:
        # Variational path: reuse the caller-supplied noise for each element
        # instead of drawing fresh randomness.
        def _apply_variational(index, value, noise):
            return self._variational_recurrent_dropout_value(
                index, value, noise, keep_prob)

        return _enumerated_map_structure(
            _apply_variational, values, recurrent_noise)

    # Standard path: independent dropout per element, seeded by position.
    def _apply_standard(index, value):
        element_seed = self._gen_seed(salt_prefix, index)
        return nn_ops.dropout(value, keep_prob=keep_prob, seed=element_seed)

    return _enumerated_map_structure(_apply_standard, values)
评论列表
文章目录