def inference_step(self, batch_mask, prob_compare, prob, counter, state,
                   premise, hypothesis, acc_states):
    # Written against the pre-1.0 TensorFlow API: tf.concat(dim, values)
    # argument order and the private tf.nn.rnn_cell._linear helper.
    # Apply dropout to both sentence representations at train time only.
    if self.config.keep_prob < 1.0 and self.is_training:
        premise = tf.nn.dropout(premise, self.config.keep_prob)
        hypothesis = tf.nn.dropout(hypothesis, self.config.keep_prob)
    # Attend over the hypothesis with the current state, then over the
    # premise with the state augmented by the hypothesis summary.
    hyp_attn = self.attention(state, hypothesis, "hyp_attn")
    state_for_premise = tf.concat(1, [state, hyp_attn])
    prem_attn = self.attention(state_for_premise, premise, "prem_attn")

    # Gate each attention summary on the state, both summaries, and their
    # elementwise interaction before feeding the recurrent cell.
    state_for_gates = tf.concat(1, [state, hyp_attn, prem_attn,
                                    prem_attn * hyp_attn])
    hyp_gate = self.gate_mechanism(state_for_gates, "hyp_gate")
    prem_gate = self.gate_mechanism(state_for_gates, "prem_gate")

    # Renamed from "input", which shadows the Python builtin.
    inference_input = tf.concat(1, [hyp_gate * hyp_attn,
                                    prem_gate * prem_attn])
    output, new_state = self.inference_cell(inference_input, state)
    # Halting probability for this ponder step (as in Graves' ACT): a
    # sigmoid unit over the new state, squeezed to shape [batch_size].
    with tf.variable_scope('sigmoid_activation_for_pondering'):
        p = tf.squeeze(tf.sigmoid(tf.nn.rnn_cell._linear(new_state, 1, True)))

    # An example keeps pondering while its accumulated halting probability
    # stays below 1 - epsilon; examples that already halted stay masked out.
    new_batch_mask = tf.logical_and(tf.less(prob + p, self.one_minus_eps),
                                    batch_mask)
    new_float_mask = tf.cast(new_batch_mask, tf.float32)

    # prob accumulates p only for examples still running after this step;
    # prob_compare also counts the step that crosses the threshold.
    prob += p * new_float_mask
    prob_compare += p * tf.cast(batch_mask, tf.float32)
    # On the final ponder step an example's state is weighted by the
    # remainder 1 - prob, so the weights over all steps sum to exactly 1.
    def use_remainder():
        remainder = tf.constant(1.0, tf.float32, [self.batch_size]) - prob
        remainder_expanded = tf.expand_dims(remainder, 1)
        tiled_remainder = tf.tile(remainder_expanded,
                                  [1, self.config.inference_size])
        return (new_state * tiled_remainder) + acc_states

    # On intermediate steps the state is weighted by the halting
    # probability of the examples still running.
    def normal():
        p_expanded = tf.expand_dims(p * new_float_mask, 1)
        tiled_p = tf.tile(p_expanded, [1, self.config.inference_size])
        return (new_state * tiled_p) + acc_states
    # Count this step only for examples that are still running; pondering
    # also stops once the per-example step cap self.N is reached.
    counter += tf.constant(1.0, tf.float32, [self.batch_size]) * new_float_mask
    counter_condition = tf.less(counter, self.N)
    condition = tf.reduce_any(tf.logical_and(new_batch_mask, counter_condition))

    acc_state = tf.cond(condition, normal, use_remainder)

    # Loop variables are returned in the order they were received, so this
    # function can serve directly as a tf.while_loop body.
    return (new_batch_mask, prob_compare, prob, counter, new_state,
            premise, hypothesis, acc_state)
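
Because the step takes and returns the same tuple of loop variables, it plugs straight into tf.while_loop: the loop keeps calling inference_step until every example has halted or hit the step cap self.N. Below is a minimal sketch of such a driver; do_inference_steps, initial_state, and the zero initialisation are my assumptions about the surrounding class, not code from the snippet above.

def do_inference_steps(self, initial_state, premise, hypothesis):
    # Every example starts un-halted, with zeroed halting probability,
    # step counter, and accumulated state.
    batch_mask = tf.constant(True, tf.bool, [self.batch_size])
    prob_compare = tf.zeros([self.batch_size])
    prob = tf.zeros([self.batch_size])
    counter = tf.zeros([self.batch_size])
    acc_states = tf.zeros_like(initial_state)

    def halting_cond(batch_mask, prob_compare, prob, counter,
                     state, premise, hypothesis, acc_states):
        # Keep stepping while any example is un-halted and under the cap;
        # this mirrors the tf.cond condition inside inference_step.
        return tf.reduce_any(
            tf.logical_and(batch_mask, tf.less(counter, self.N)))

    loop_vars = [batch_mask, prob_compare, prob, counter,
                 initial_state, premise, hypothesis, acc_states]
    (_, _, prob, iterations,
     _, _, _, final_state) = tf.while_loop(halting_cond,
                                           self.inference_step, loop_vars)
    # final_state is the probability-weighted mixture of per-step states;
    # iterations and prob are the raw material for an ACT ponder-cost term.
    return final_state, prob, iterations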