def inference_step(self, batch_mask, prob_compare, prob, counter, state, premise, hypothesis, acc_states):
    """Run one Adaptive Computation Time (ACT) pondering step for the batch.

    Attends over the hypothesis (queried by ``state``), then over the premise
    (queried by ``state`` concatenated with the hypothesis attention). The two
    attention results form ``new_state``. A sigmoid of a linear projection of
    ``new_state`` gives a per-example halting probability ``p``, and
    ``new_state`` is folded into the running weighted sum ``acc_states``.

    Halting weights are decided per example:
      * an example that keeps pondering adds ``p * new_state``;
      * an example that stops on this step adds ``(1 - prob) * new_state``,
        where ``prob`` is its probability mass from earlier steps. It stops
        because ``prob + p`` reaches ``self.one_minus_eps`` or its counter
        reaches ``self.N``. This weighting makes its weights sum to 1;
      * an example that stopped on an earlier step adds nothing.

    The returned tuple mirrors the argument order, which suggests this is a
    ``tf.while_loop`` body (assumption -- confirm against the caller).

    Args:
        batch_mask: bool per example; True while the example is still pondering.
        prob_compare: float per example; running sum of ``p`` over every step
            the example was active, including its halting step.
        prob: float per example; running sum of ``p`` over the steps on which
            the example stayed under the probability threshold.
        counter: float per example; number of steps on which the example
            stayed under the probability threshold.
        state: previous step output used as the attention query.
        premise, hypothesis: inputs passed to ``self.attention`` (shapes are
            defined by that method -- not visible here). They are returned
            unchanged.
        acc_states: running weighted sum of step outputs; presumably the same
            shape as ``new_state`` -- TODO confirm.

    Returns:
        ``(new_batch_mask, prob_compare, prob, counter, new_state, premise,
        hypothesis, acc_state)``. ``new_batch_mask`` is False for every
        example that has stopped, including examples stopped by the step cap.
    """
    # Dropout goes on local copies only. If the dropped-out tensors were
    # returned as loop variables, dropout (and its 1/keep_prob rescaling)
    # would compound once per pondering step.
    premise_in, hypothesis_in = premise, hypothesis
    if self.config.keep_prob < 1.0 and self.is_training:
        premise_in = tf.nn.dropout(premise, self.config.keep_prob)
        hypothesis_in = tf.nn.dropout(hypothesis, self.config.keep_prob)

    # Attend over the hypothesis, then over the premise conditioned on it.
    hyp_attn = self.attention(state, hypothesis_in, "hyp_attn")
    state_for_premise = tf.concat(1, [state, hyp_attn])
    prem_attn = self.attention(state_for_premise, premise_in, "prem_attn")
    new_state = tf.concat(1, [hyp_attn, prem_attn])

    with tf.variable_scope('sigmoid_activation_for_pondering'):
        # Squeeze only the size-1 output axis so a batch of one keeps its
        # batch dimension instead of collapsing to a scalar.
        p = tf.squeeze(tf.sigmoid(tf.nn.rnn_cell._linear(new_state, 1, True)), [1])

    # Probability mass not yet assigned before this step. It is the
    # halting-step weight (ACT's remainder R), so capture it before `prob`
    # is updated below.
    remainder = 1.0 - prob

    # Examples that stay below the probability threshold on this step.
    new_batch_mask = tf.logical_and(tf.less(prob + p, self.one_minus_eps), batch_mask)
    new_float_mask = tf.cast(new_batch_mask, tf.float32)
    prob += p * new_float_mask
    prob_compare += p * tf.cast(batch_mask, tf.float32)
    counter += new_float_mask

    # An example keeps pondering only while it is under both the probability
    # threshold and the step cap. Every other example that was active at the
    # start of this step halts now.
    still_running = tf.logical_and(new_batch_mask, tf.less(counter, self.N))
    halting = tf.logical_and(batch_mask, tf.logical_not(still_running))

    # Per-example weight for this step's output: p while pondering, the
    # remainder on the halting step, 0 once halted. The weights are chosen
    # per example rather than with one batch-wide tf.cond, so an early halter
    # is weighted by its own halting-step state, not a later one.
    step_weight = (p * tf.cast(still_running, tf.float32)
                   + remainder * tf.cast(halting, tf.float32))
    # Broadcast the per-example weight across the feature axis.
    acc_state = acc_states + new_state * tf.expand_dims(step_weight, 1)

    return (still_running, prob_compare, prob, counter, new_state, premise, hypothesis, acc_state)
# 评论列表 (Comment list)
# 文章目录 (Table of contents)