def __init__(self, cell, zoneout_prob, is_training=True):
    """Wrap ``cell`` and record the zoneout configuration.

    Args:
      cell: the ``RNNCell`` instance being wrapped.
      zoneout_prob: zoneout probability. Only a Python ``float`` is
        range-checked here (it must lie in [0, 1]). Other values, e.g. a
        tuple or a tensor, are stored unchecked.
        NOTE(review): confirm how non-float values are consumed elsewhere.
      is_training: stored as ``self.is_training``. It presumably selects
        train-time vs. inference-time zoneout elsewhere in the class; verify.

    Raises:
      TypeError: if ``cell`` is not an ``RNNCell``.
      ValueError: if ``zoneout_prob`` is a float outside [0, 1].
    """
    if not isinstance(cell, RNNCell):
        raise TypeError("The parameter cell is not an RNNCell.")
    # Constructor that packs a sequence of values into a tuple: an
    # LSTMStateTuple for BasicLSTMCell, a plain tuple otherwise.
    # Presumably used to rebuild the wrapped cell's state; confirm at call site.
    if isinstance(cell, BasicLSTMCell):
        self._tuple = lambda x: LSTMStateTuple(*x)
    else:
        self._tuple = tuple
    if isinstance(zoneout_prob, float) and not 0.0 <= zoneout_prob <= 1.0:
        # Use %s, not %d. %d truncates the float (e.g. 1.5 -> "1"), so the
        # message would not show the value that failed validation.
        raise ValueError("Parameter zoneout_prob must be between 0 and 1: %s"
                         % zoneout_prob)
    self._cell = cell
    self._zoneout_prob = zoneout_prob
    self.is_training = is_training
评论列表
文章目录