def do_act_steps(self, premise, hypothesis):
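    """Run Adaptive Computation Time (ACT) steps over the premise and
    hypothesis representations.

    Returns the accumulated state together with the per-example `remainders`
    and `iterations` tensors produced by the halting loop.
    """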
    self.rep_size = premise.get_shape()[-1].value

    # Per-example halting threshold (1 - eps) and maximum number of ACT steps.
    self.one_minus_eps = tf.constant(1.0 - self.config.eps, tf.float32, [self.batch_size])
    self.N = tf.constant(self.config.max_computation, tf.float32, [self.batch_size])

    # Per-example accumulators: the halting probability, the value compared
    # against the threshold in the stopping test, and the step counter.
    prob = tf.constant(0.0, tf.float32, [self.batch_size], name="prob")
    prob_compare = tf.constant(0.0, tf.float32, [self.batch_size], name="prob_compare")
    counter = tf.constant(0.0, tf.float32, [self.batch_size], name="counter")

    # Recurrent state, step index, and the accumulator for intermediate states.
    initial_state = tf.zeros([self.batch_size, 2 * self.rep_size], tf.float32, name="state")
    i = tf.constant(0, tf.int32, name="index")
    acc_states = tf.zeros_like(initial_state, tf.float32, name="state_accumulator")

    # True for batch elements that are still computing (have not yet halted).
    batch_mask = tf.constant(True, tf.bool, [self.batch_size])
    # Tensor arrays to collect information about the run:
    array_probs = tf.TensorArray(tf.float32, 0, dynamic_size=True)
    premise_attention = tf.TensorArray(tf.float32, 0, dynamic_size=True)
    hypothesis_attention = tf.TensorArray(tf.float32, 0, dynamic_size=True)
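    # These arrays are filled by inference_step with the per-step halting
    # probabilities and premise/hypothesis attention weights, and are stacked
    # into self.ACTPROB / self.ACTPREMISEATTN / self.ACTHYPOTHESISATTN below.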
    # The while loop stops when this predicate is False, i.e. when no batch
    # element still satisfies (accumulated probability < 1 - eps AND counter < N).
    pred = lambda i, array_probs, premise_attention, hypothesis_attention, batch_mask, \
        prob_compare, prob, counter, state, premise, hypothesis, acc_state: \
        tf.reduce_any(
            tf.logical_and(
                tf.less(prob_compare, self.one_minus_eps),
                tf.less(counter, self.N)))

    # Run inference_step until every batch element has either crossed the
    # halting threshold or reached the maximum number of steps.
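    # tf.while_loop threads these 12 tensors through self.inference_step on
    # every iteration, so the step function must accept and return them in the
    # same order.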
    i, array_probs, premise_attention, hypothesis_attention, _, _, remainders, iterations, _, _, _, state = \
        tf.while_loop(pred, self.inference_step,
                      [i, array_probs, premise_attention, hypothesis_attention,
                       batch_mask, prob_compare, prob,
                       counter, initial_state, premise, hypothesis, acc_states])
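    # TensorArray.pack() (renamed stack() in TF 1.0) concatenates the per-step
    # entries into a single tensor for later analysis.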
    self.ACTPROB = array_probs.pack()
    self.ACTPREMISEATTN = premise_attention.pack()
    self.ACTHYPOTHESISATTN = hypothesis_attention.pack()

    return state, remainders, iterations

# Source file: ACTAttnAnalysisModel.py
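
# -----------------------------------------------------------------------------
# Hedged sketch, not part of the original listing: one common way to turn the
# values returned by do_act_steps into the ACT "ponder" penalty from Graves
# (2016), which discourages taking unnecessary computation steps. It assumes
# that `remainders` holds the accumulated halting probability (so the
# per-example remainder term is 1 - remainders), that the class exposes a
# scalar `config.ponder_time_penalty`, and that TensorFlow is imported as `tf`
# at the top of the file; none of these names are confirmed by the code above.
def calculate_ponder_cost(self, remainders, iterations):
    # Per-example ponder cost: number of ACT steps plus the halting remainder.
    ponder = iterations + (1.0 - remainders)
    # Sum over the batch and scale by the assumed penalty coefficient.
    return self.config.ponder_time_penalty * tf.reduce_sum(ponder)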