def do_inference_steps(self, initial_state, premise, hypothesis):
    """Repeatedly apply ``self.inference_step`` inside a ``tf.while_loop``.

    Side effects: sets ``self.one_minus_eps`` (``1 - self.config.eps``) and
    ``self.N`` (``self.config.max_computation``).  Both are float32 vectors of
    length ``self.batch_size``, and the loop predicate compares against them.
    ``self.inference_step`` presumably reads them as well (not visible here).

    NOTE(review): the halting test (probability vs. ``1 - eps`` and step
    counter vs. ``N``) looks like an Adaptive Computation Time loop; confirm
    against ``inference_step``.

    Args:
        initial_state: Starting state tensor.  The state accumulator is
            initialised as float32 zeros with the same shape.
        premise: Tensor threaded through the loop unchanged by this method
            (shape not visible here; TODO confirm).
        hypothesis: Tensor threaded through the loop unchanged by this
            method (shape not visible here; TODO confirm).

    Returns:
        A tuple ``(state, remainders, iterations)`` containing the final
        values of these loop variables, as produced by
        ``self.inference_step``:
        - the accumulator (8th loop variable),
        - the ``prob`` slot (3rd loop variable),
        - the ``counter`` slot (4th loop variable).
    """
    self.one_minus_eps = tf.constant(1.0 - self.config.eps, tf.float32, [self.batch_size])
    self.N = tf.constant(self.config.max_computation, tf.float32, [self.batch_size])
    prob = tf.constant(0.0, tf.float32, [self.batch_size], name="prob")
    prob_compare = tf.constant(0.0, tf.float32, [self.batch_size], name="prob_compare")
    counter = tf.constant(0.0, tf.float32, [self.batch_size], name="counter")
    acc_states = tf.zeros_like(initial_state, tf.float32, name="state_accumulator")
    batch_mask = tf.constant(True, tf.bool, [self.batch_size])

    def pred(batch_mask, prob_compare, prob, counter, state, premise, hypothesis, acc_state):
        # tf.while_loop stops when this predicate is FALSE.  It stays TRUE
        # while ANY batch row still has prob_compare < 1 - eps AND
        # counter < N.  The loop therefore ends only once every row has
        # crossed at least one of the two thresholds.
        return tf.reduce_any(
            tf.logical_and(
                tf.less(prob_compare, self.one_minus_eps),
                tf.less(counter, self.N)))

    # The order of the loop variables must match pred's parameters and the
    # arguments/return value of self.inference_step.
    loop_vars = [batch_mask, prob_compare, prob,
                 counter, initial_state, premise, hypothesis, acc_states]
    _, _, remainders, iterations, _, _, _, state = tf.while_loop(
        pred, self.inference_step, loop_vars)
    return state, remainders, iterations
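# (Scraped web-page residue, translated from Chinese: "Comment list")
# (Scraped web-page residue, translated from Chinese: "Table of contents")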