def __call__(self, inputs, state, scope=None):
    """Run ``self.act_step`` in a while loop and return the final output/state.

    The loop is built with ``control_flow_ops.while_loop`` and keeps
    iterating for as long as any entry of the per-example boolean flag is
    still True.

    Args:
        inputs: Input tensor; the size of its first dimension is used as the
            batch size.
        state: Cell state. When ``self.state_is_tuple`` is set, it is
            concatenated along dimension 1 before entering the loop.
        scope: Optional variable scope; falls back to the class name.

    Returns:
        ``(output, next_state)`` taken from the last two loop variables once
        the loop terminates. When ``self.state_is_tuple`` is set,
        ``next_state`` is split in two along dimension 1 and wrapped in
        ``rnn_cell._LSTMStateTuple``.

    Side effects:
        Overwrites ``self.batch_size`` and ``self.one_minus_eps``, and
        appends ``1 - probs`` to ``self.ACT_remainder`` and the final loop
        counter to ``self.ACT_iterations``.
    """
    # NOTE(review): the pasted source had lost its indentation; the extent of
    # the variable scope (closed right after the while loop) is reconstructed
    # and should be confirmed against the upstream file.
    scope_name = scope or type(self).__name__
    with vs.variable_scope(scope_name):
        # Constants and counters defined inside the cell scope; they drive
        # the while loop over self.act_step.
        if self.state_is_tuple:
            # NOTE(review): axis-first argument order -- presumably the
            # pre-1.0 TensorFlow signature of concat; confirm TF version.
            state = array_ops.concat(1, state)

        self.batch_size = tf.shape(inputs)[0]
        self.one_minus_eps = tf.fill(
            [self.batch_size],
            tf.constant(1.0 - self.epsilon, dtype=tf.float32),
        )

        # Initial values of the loop variables.
        initial_prob = tf.fill(
            [self.batch_size],
            tf.constant(0.0, dtype=tf.float32),
            "prob",
        )
        initial_counter = tf.zeros_like(initial_prob, tf.float32, name="counter")
        output_acc = tf.fill(
            [self.batch_size, self.output_size],
            0.0,
            name='output_accumulator',
        )
        state_acc = tf.zeros_like(state, tf.float32, name="state_accumulator")
        running = tf.fill([self.batch_size], True, name="flag")

        def _any_running(still_running, *_unused):
            # Only the flag decides whether to continue; the remaining loop
            # variables are accepted and ignored.
            return tf.reduce_any(still_running)

        loop_vars = [
            running,
            initial_prob,
            initial_counter,
            state,
            inputs,
            output_acc,
            state_acc,
        ]
        (_, final_probs, n_iterations, _, _, output, next_state) = (
            control_flow_ops.while_loop(
                _any_running, self.act_step, loop_vars=loop_vars
            )
        )

    # Record per-call statistics on the cell.
    self.ACT_remainder.append(1 - final_probs)
    self.ACT_iterations.append(n_iterations)

    if self.state_is_tuple:
        # Undo the concatenation done before the loop.
        c_part, h_part = array_ops.split(1, 2, next_state)
        next_state = rnn_cell._LSTMStateTuple(c_part, h_part)

    return output, next_state
评论列表
文章目录