ACTDAModel.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:act-rte-inference 作者: DeNeutoy 项目源码 文件源码
def inference_step(self,batch_mask, prob_compare,prob,counter, state, premise, hypothesis, acc_states):
        """Run one pondering step of the adaptive-computation-time inference loop.

        The returned tuple has the same layout as the arguments, so this method
        can serve as a ``tf.while_loop`` body (presumably how the caller uses
        it -- confirm at the call site).

        Args:
            batch_mask: bool tensor, True for examples still pondering before
                this step (assumed shape [batch] -- TODO confirm).
            prob_compare: float32 accumulated halting probability, including
                the p of the step at which an example halts.
            prob: float32 accumulated halting probability over the steps at
                which an example kept running (excludes the halting step).
            counter: float32 number of steps each example kept running.
            state: recurrent state fed to ``self.inference_cell``.
            premise: encoded premise attended over by ``self.attention``.
            hypothesis: encoded hypothesis attended over by ``self.attention``.
            acc_states: running weighted sum of cell states.

        Returns:
            Tuple ``(new_batch_mask, prob_compare, prob, counter, new_state,
            premise, hypothesis, acc_state)``. ``premise`` and ``hypothesis``
            are returned exactly as received (without dropout applied).
        """
        # Apply dropout to local copies only. premise/hypothesis are handed back
        # in the loop-variable tuple; dropping out the returned tensors would
        # re-apply dropout on every pondering step and compound it.
        premise_in, hypothesis_in = premise, hypothesis
        if self.config.keep_prob < 1.0 and self.is_training:
            premise_in = tf.nn.dropout(premise, self.config.keep_prob)
            hypothesis_in = tf.nn.dropout(hypothesis, self.config.keep_prob)

        # Attend over the hypothesis, then over the premise conditioned on the
        # current state plus the hypothesis attention.
        hyp_attn = self.attention(state, hypothesis_in, "hyp_attn")
        state_for_premise = tf.concat(1, [state, hyp_attn])
        prem_attn = self.attention(state_for_premise, premise_in, "prem_attn")
        state_for_gates = tf.concat(1, [state, hyp_attn, prem_attn, prem_attn * hyp_attn])

        hyp_gate = self.gate_mechanism(state_for_gates, "hyp_gate")
        prem_gate = self.gate_mechanism(state_for_gates, "prem_gate")

        cell_input = tf.concat(1, [hyp_gate * hyp_attn, prem_gate * prem_attn])
        _, new_state = self.inference_cell(cell_input, state)

        # Per-example halting probability. Squeeze only axis 1 so that a batch
        # of size 1 keeps shape [1] instead of collapsing to a scalar.
        with tf.variable_scope('sigmoid_activation_for_pondering'):
            p = tf.squeeze(tf.sigmoid(tf.nn.rnn_cell._linear(new_state, 1, True)), [1])

        # An example keeps pondering only if it was still running and adding
        # this step's p keeps its total below 1 - eps.
        new_batch_mask = tf.logical_and(tf.less(prob + p, self.one_minus_eps), batch_mask)
        new_float_mask = tf.cast(new_batch_mask, tf.float32)
        # prob only grows for examples that continue; prob_compare also counts
        # the step at which an example halts.
        prob += p * new_float_mask
        prob_compare += p * tf.cast(batch_mask, tf.float32)

        def use_remainder():
            # Weight new_state by each example's remainder 1 - prob
            # (broadcast across the state dimension).
            remainder = tf.expand_dims(1.0 - prob, 1)
            return new_state * remainder + acc_states

        def normal():
            # Continuing examples weight new_state by p; halted ones add 0.
            weight = tf.expand_dims(p * new_float_mask, 1)
            return new_state * weight + acc_states

        # Count this step only for examples that continue pondering.
        counter += new_float_mask
        counter_condition = tf.less(counter, self.N)
        condition = tf.reduce_any(tf.logical_and(new_batch_mask, counter_condition))

        # NOTE(review): the remainder branch runs only once, when no example is
        # both still pondering and under the N-step limit. It then adds
        # (1 - prob) * new_state for EVERY example, using this final step's
        # new_state. Examples that halted earlier get their remainder weight on
        # a later state than the one they halted at. This differs from
        # per-example remainder handling in Graves' ACT, so confirm it is
        # intended.
        acc_state = tf.cond(condition, normal, use_remainder)

        return (new_batch_mask, prob_compare, prob, counter, new_state, premise, hypothesis, acc_state)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号