def train(self, batch):
  # decide whether to horizontally flip this whole batch (simple data augmentation)
  flip_horizontally = np.random.random() < 0.5

  if VERBOSE_DEBUG:
    print "batch.action"
    print batch.action.T
    print "batch.reward", batch.reward.T
    print "batch.terminal_mask", batch.terminal_mask.T
    print "flip_horizontally", flip_horizontally
    print "weights", batch.weight.T
    # debug-only forward pass: fetch intermediate values for this batch
    # (and trigger printing of gradient norms) without running the train op.
    values = tf.get_default_session().run([self._l_values, self.value_net.value,
                                            self.advantage, self.target_value_net.value,
                                            self.print_gradient_norms],
                                           feed_dict={self.input_state: batch.state_1,
                                                      self.input_action: batch.action,
                                                      self.reward: batch.reward,
                                                      self.terminal_mask: batch.terminal_mask,
                                                      self.input_state_2: batch.state_2,
                                                      self.importance_weight: batch.weight,
                                                      base_network.IS_TRAINING: True,
                                                      base_network.FLIP_HORIZONTALLY: flip_horizontally})
    values = [np.squeeze(v) for v in values]
    print "_l_values", values[0].T
    print "value_net.value ", values[1].T
    print "advantage ", values[2].T
    print "target_value_net.value ", values[3].T

  # the actual training step: check tensors for NaN/Inf, apply gradients,
  # and return the batch loss.
  _, _, l = tf.get_default_session().run([self.check_numerics, self.train_op,
                                          self.loss],
                                         feed_dict={self.input_state: batch.state_1,
                                                    self.input_action: batch.action,
                                                    self.reward: batch.reward,
                                                    self.terminal_mask: batch.terminal_mask,
                                                    self.input_state_2: batch.state_2,
                                                    self.importance_weight: batch.weight,
                                                    base_network.IS_TRAINING: True,
                                                    base_network.FLIP_HORIZONTALLY: flip_horizontally})
  return l
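
# Illustrative sketch (not part of the original code): the batch passed to
# train() above only needs the six fields referenced in the feed_dict. A
# minimal container could look like the following; the names match the
# feed_dict, but the shapes, dtypes and conventions are assumptions for
# demonstration only.
import collections
import numpy as np

Batch = collections.namedtuple(
    "Batch", ["state_1", "action", "reward", "terminal_mask", "state_2", "weight"])

def dummy_batch(batch_size=16, state_shape=(50, 50, 6), action_dim=2):
  # state_1 / state_2: observations before and after each transition
  # terminal_mask: assumed convention of 0.0 where the episode ended, 1.0 otherwise
  # weight: per-example importance weights (e.g. from prioritised replay)
  ones = np.ones((batch_size, 1), dtype=np.float32)
  return Batch(
      state_1=np.random.rand(batch_size, *state_shape).astype(np.float32),
      action=np.random.uniform(-1, 1, (batch_size, action_dim)).astype(np.float32),
      reward=np.random.rand(batch_size, 1).astype(np.float32),
      terminal_mask=ones.copy(),
      state_2=np.random.rand(batch_size, *state_shape).astype(np.float32),
      weight=ones.copy())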