def __call__(self, inputs, state, timestep=0, scope=None):
    """Run the cell's inner while loop and return its final output and state.

    The loop body is ``self.ACTStep``. Iteration continues while at least
    one batch element still has ``prob_compare < self.one_minus_eps`` and
    ``counter < self.N``.

    Args:
        inputs: Passed through the loop unchanged as one of its loop variables.
        state: Initial state fed to the loop; it is also the template
            (via ``zeros_like``) for both accumulators.
        timestep: Not used by this method.
        scope: Optional variable-scope name; defaults to the class name.

    Returns:
        A tuple ``(output, next_state)`` taken from the last two loop
        variables once the loop finishes.

    Side effects:
        Appends ``mean(1 - remainders)`` to ``self.ACT_remainder`` and
        ``mean(iterations)`` to ``self.ACT_iterations``.
    """
    with vs.variable_scope(scope or type(self).__name__):
        # Per-batch-element bookkeeping tensors that drive the while loop.
        halting_prob = tf.constant(0.0, tf.float32, [self.batch_size], name="prob")
        halting_prob_check = tf.constant(
            0.0, tf.float32, [self.batch_size], name="prob_compare")
        step_count = tf.constant(0.0, tf.float32, [self.batch_size], name="counter")
        # Both accumulators take their shape from `state`.
        output_acc = tf.zeros_like(state, tf.float32, name="output_accumulator")
        state_acc = tf.zeros_like(state, tf.float32, name="state_accumulator")
        active_mask = tf.constant(True, tf.bool, [self.batch_size])

        def keep_looping(_mask, prob_check, _prob, steps,
                         _state, _inputs, _acc_output, _acc_state):
            # True while ANY batch element is still below both limits, so
            # the loop only ends once every element has passed the
            # probability threshold or the step cap.
            under_threshold = tf.less(prob_check, self.one_minus_eps)
            under_step_cap = tf.less(steps, self.N)
            return tf.reduce_any(tf.logical_and(under_threshold, under_step_cap))

        # The order of the loop variables must match what ACTStep expects.
        loop_vars = [active_mask, halting_prob_check, halting_prob,
                     step_count, state, inputs, output_acc, state_acc]
        (_, _, remainders, iterations,
         _, _, output, next_state) = control_flow_ops.while_loop(
             keep_looping, self.ACTStep, loop_vars)

    # NOTE(review): the original indentation was lost; the bookkeeping below
    # is assumed to sit outside the variable scope -- confirm.
    # Record the mean remainder term and the mean iteration count for this call.
    mean_remainder = tf.reduce_mean(1 - remainders)
    mean_iterations = tf.reduce_mean(iterations)
    self.ACT_remainder.append(mean_remainder)
    self.ACT_iterations.append(mean_iterations)
    return output, next_state
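# (web-page residue, not part of the code) Comment list
# (web-page residue, not part of the code) Table of contents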