switchable_dropout_wrapper.py source code

python

Project: paraphrase-id-tensorflow  Author: nelson-liu
    def __call__(self, inputs, state, scope=None):
        # Get the dropped-out outputs and state
        outputs_do, new_state_do = super(SwitchableDropoutWrapper,
                                         self).__call__(
                                             inputs, state, scope=scope)
        tf.get_variable_scope().reuse_variables()
        # Get the un-dropped-out outputs and state
        outputs, new_state = self._cell(inputs, state, scope)

        # Use the dropped-out outputs and state if we are training.
        # Otherwise use the un-dropped-out state, and the un-dropped-out
        # outputs scaled by the output keep probability.
        outputs = tf.cond(self.is_train, lambda: outputs_do,
                          lambda: outputs * (self._output_keep_prob))
        if isinstance(state, tuple):
            new_state = state.__class__(
                *[tf.cond(self.is_train, lambda: new_state_do_i,
                          lambda: new_state_i)
                  for new_state_do_i, new_state_i in
                  zip(new_state_do, new_state)])
        else:
            new_state = tf.cond(self.is_train, lambda: new_state_do,
                                lambda: new_state)
        return outputs, new_state
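
The switching logic of the method above can be sketched without TensorFlow. This is a minimal illustration, not the project's implementation: the names `dropout`, `switchable_dropout`, and `select_state` are hypothetical, plain Python lists stand in for tensors, and an ordinary `if` stands in for `tf.cond`. It assumes inverted dropout (kept values are scaled by 1 / keep_prob during training), and it returns the values unchanged outside training, whereas the method above additionally multiplies the outputs by the output keep probability.

```python
import random
from collections import namedtuple


def dropout(values, keep_prob, rng):
    """Inverted dropout: zero each value with probability 1 - keep_prob
    and scale the surviving values by 1 / keep_prob."""
    return [v / keep_prob if rng.random() < keep_prob else 0.0
            for v in values]


def switchable_dropout(values, keep_prob, is_train, rng=None):
    """Return dropped-out values when training, untouched values otherwise."""
    if is_train:
        return dropout(values, keep_prob, rng or random.Random())
    return list(values)


def select_state(state_do, state, is_train):
    """Choose between the dropped-out and un-dropped-out state.

    For namedtuple-style states (an assumption of this sketch), the choice
    is made element by element and the original state class is rebuilt,
    mirroring the `state.__class__(*[...])` pattern in the method above.
    """
    if isinstance(state, tuple):
        return state.__class__(
            *[do_i if is_train else i for do_i, i in zip(state_do, state)])
    return state_do if is_train else state


# Hypothetical two-part state, similar in shape to an LSTM (c, h) pair.
State = namedtuple("State", ["c", "h"])
```

With `is_train=False`, `switchable_dropout` returns a copy of its input and `select_state` returns the un-dropped-out state with its original tuple type preserved; with `is_train=True`, each output is either zero or the input divided by `keep_prob`.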