def __call__(self, inputs, state, scope=None):
    """Run the wrapped cell with input/output dropout switched on by training mode.

    Dropout is applied to ``inputs`` before calling ``self._cell`` and to the
    cell's ``output`` afterwards. Each is gated by ``self._is_train``, and
    each is skipped entirely when its keep probability is a Python float >= 1.

    Args:
      inputs: Input passed to the wrapped cell ``self._cell``.
      state: Recurrent state passed to the wrapped cell.
      scope: Optional scope forwarded to the wrapped cell as ``scope=``.
        When None (the default), the cell is called as ``(inputs, state)``.

    Returns:
      ``(output, new_state)`` from the wrapped cell. ``output`` has dropout
      applied when training.
    """

    def _switchable_dropout(x, keep_prob):
        # Return ``x`` with dropout applied only in training mode.
        # A Python float keep probability that is not < 1 means "no dropout".
        # Anything else (a float < 1, or a non-float such as a tensor or
        # placeholder) gets dropout. This mirrors the original condition
        # exactly.
        if isinstance(keep_prob, float) and not keep_prob < 1:
            return x
        is_train = self._is_train
        # A plain Python bool is resolved here and now. TF 1.x ``tf.cond``
        # raises TypeError for a Python-bool predicate, and this way the
        # unused branch is never built.
        if isinstance(is_train, bool):
            return dropout(x, keep_prob, seed=self._seed) if is_train else x
        # Build the dropout op inside the true branch. Ops created outside
        # the cond lambdas run regardless of the predicate, so building them
        # there would still compute dropout at inference time.
        return tf.cond(is_train,
                       lambda: dropout(x, keep_prob, seed=self._seed),
                       lambda: x)

    inputs = _switchable_dropout(inputs, self._input_keep_prob)
    if scope is None:
        # Same call shape as before, for cells that only take (inputs, state).
        output, new_state = self._cell(inputs, state)
    else:
        output, new_state = self._cell(inputs, state, scope=scope)
    output = _switchable_dropout(output, self._output_keep_prob)
    return output, new_state
评论列表
文章目录