def __call__(self, enc_input, dec_input_indices, valid_indices, left_indices, right_indices, values, valid_masks=None):
    """Build the forward, cost, and per-bucket training graphs.

    Args:
        enc_input: encoder input batch tensor (only fed to ``self.encoder``;
            its leading dimension is read into ``batch_size`` here).
        dec_input_indices: per-step decoder input index tensors.
        valid_indices: per-step valid-position index tensors for the decoder.
        left_indices: per-step left-child index tensors for the decoder.
        right_indices: per-step right-child index tensors for the decoder.
        values: per-step reward/return tensors used in the policy-gradient cost.
        valid_masks: optional per-step masks forwarded to the decoder.

    Returns:
        Tuple ``(dec_hiddens, dec_actions, train_ops)`` where ``train_ops[i]``
        is a dummy constant whose evaluation (via control dependencies) runs
        the parameter update and baseline updates for bucket ``self.buckets[i]``.
    """
    batch_size = tf.shape(enc_input)[0]  # NOTE(review): currently unused below
    # forward computation graph
    with tf.variable_scope(self.scope):
        # encoder output (forward final state seeds the decoder)
        enc_memory, enc_final_state_fw, _ = self.encoder(enc_input)
        # decoder: per-step hidden states, sampled actions, and action log-probs
        dec_hiddens, dec_actions, dec_act_logps = self.decoder(
            enc_memory, dec_input_indices,
            valid_indices, left_indices, right_indices,
            valid_masks, init_state=enc_final_state_fw)
    # cost: REINFORCE-style -logp * value per decode step. An exponential
    # moving-average baseline is maintained per step, but as written it is
    # only tracked, NOT subtracted from the value — the commented line shows
    # the variance-reduced alternative. TODO(review): confirm this is intended.
    costs = []
    update_ops = []
    for step_idx, (act_logp, value, baseline) in enumerate(zip(dec_act_logps, values, self.baselines)):
        # variance-reduced alternative:
        # costs.append(-tf.reduce_mean(act_logp * (value - baseline)))
        new_baseline = self.bl_ratio * baseline + (1 - self.bl_ratio) * tf.reduce_mean(value)
        costs.append(-tf.reduce_mean(act_logp * value))
        update_ops.append(tf.assign(baseline, new_baseline))
    # gradient computation graph: one train op per bucket, each covering
    # steps 0..limit-1 so shorter sequences avoid building the full graph
    self.params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.scope)
    train_ops = []
    for limit in self.buckets:
        # BUGFIX: parenthesized print is valid (and prints identically)
        # under both Python 2 and Python 3; the original used the
        # Python-2-only print statement.
        print('0 ~ %d' % (limit - 1))
        grad_params = tf.gradients(tf.reduce_sum(tf.pack(costs[:limit])), self.params)
        if self.max_grad_norm is not None:
            # norm is only consumed by the (commented) debug tf.Print below
            clipped_gradients, norm = tf.clip_by_global_norm(grad_params, self.max_grad_norm)
        else:
            clipped_gradients = grad_params
        # BUGFIX: list(...) so apply_gradients also works on Python 3,
        # where zip returns a lazy iterator; a no-op on Python 2.
        train_op = self.optimizer.apply_gradients(
            list(zip(clipped_gradients, self.params)))
        with tf.control_dependencies([train_op] + update_ops[:limit]):
            # train_ops.append(tf.Print(tf.constant(1.), [norm]))
            # dummy tensor: fetching it triggers train_op + baseline updates
            train_ops.append(tf.constant(1.))
    return dec_hiddens, dec_actions, train_ops
#### test script
# (removed non-code scrape residue: "评论列表" / "文章目录" — blog-page
#  "comment list" / "article table of contents" artifacts, which were also
#  a Python syntax error as bare statements)