def __call__(self, inputs, state, scope=None):
    """Run the wrapped cell for one step and apply zoneout to its new state.

    For each state part with zoneout probability ``p``, zoneout keeps the
    previous value instead of the freshly computed one. When
    ``self.is_training`` is truthy, this is done stochastically through
    ``dropout``. Otherwise the deterministic expectation
    ``p * state + (1 - p) * new_state`` is used.

    Args:
        inputs: Forwarded unchanged to ``self._cell``.
        state: Previous state. When ``self.state_size`` is a tuple, this must
            be a sequence with one part per entry of ``self._zoneout_prob``.
            Otherwise it is a single state value.
        scope: Optional scope, forwarded unchanged to ``self._cell``.

    Returns:
        ``(output, new_state)``: the wrapped cell's output (untouched) and the
        zoned-out state. When the state is a tuple, the new state is rebuilt
        with ``self._tuple``.

    Raises:
        TypeError: If exactly one of ``self.state_size`` and
            ``self._zoneout_prob`` is a tuple.
        ValueError: If both are tuples but have different lengths.
    """
    state_is_tuple = isinstance(self.state_size, tuple)
    if state_is_tuple != isinstance(self._zoneout_prob, tuple):
        raise TypeError("Subdivided states need subdivided zoneouts.")
    if state_is_tuple and len(tuple(self.state_size)) != len(tuple(self._zoneout_prob)):
        raise ValueError("State and zoneout need equally many parts.")
    output, new_state = self._cell(inputs, state, scope)
    # Read once, so every state part is processed in the same mode.
    training = self.is_training

    def zoneout(new_part, old_part, prob):
        """Blend one state part, keeping ``old_part`` with probability ``prob``."""
        keep_prob = 1 - prob
        if training:
            # NOTE(review): assumes TF1-style `dropout(x, keep_prob)`, which
            # keeps each unit with probability keep_prob and scales kept
            # units by 1/keep_prob. Multiplying by keep_prob undoes that
            # scaling, so each unit ends up as exactly new_part or old_part.
            # Confirm against the module's `dropout` import.
            return keep_prob * dropout(new_part - old_part, keep_prob) + old_part
        return prob * old_part + keep_prob * new_part

    # Both branches share `zoneout`, so the single-state path uses the same
    # formula as the tuple path (it previously referenced undefined names).
    if state_is_tuple:
        new_state = self._tuple([
            zoneout(new_part, old_part, prob)
            for new_part, old_part, prob in zip(new_state, state, self._zoneout_prob)
        ])
    else:
        new_state = zoneout(new_state, state, self._zoneout_prob)
    return output, new_state
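# Usage: wrap your cells like this. With state_is_tuple=True the LSTM state is a
# tuple, so zoneout_prob must also be a tuple with one probability per state part:
# cell = ZoneoutWrapper(tf.nn.rnn_cell.LSTMCell(hidden_units, initializer=random_uniform(), state_is_tuple=True),
#                       zoneout_prob=(z_prob_cells, z_prob_states))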
# (Leftover web-page navigation labels, translated: "Comment list", "Table of contents".)