def train(self, batch):
    """Run a single training step on ``batch`` and return the fetched loss.

    The batch fields ``state_1``, ``action``, ``reward``, ``terminal_mask``
    and ``state_2`` are fed to the matching graph inputs held on ``self``,
    and ``base_network.IS_TRAINING`` is fed as ``True``. The step is run on
    the current default session (``tf.get_default_session()``).

    Args:
        batch: object exposing the five attributes listed above.

    Returns:
        The value fetched for ``self.loss``.
    """
    # NOTE(review): get_default_session() presumably returns None when no
    # session is active, which would fail here -- confirm callers always run
    # inside a session context.
    run_step = tf.get_default_session().run

    # All three are fetched in one run call; only the loss value is returned.
    fetches = [self.check_numerics, self.train_op, self.loss]

    feed = {
        self.input_state: batch.state_1,
        self.input_action: batch.action,
        self.reward: batch.reward,
        self.terminal_mask: batch.terminal_mask,
        self.input_state_2: batch.state_2,
        base_network.IS_TRAINING: True,
    }

    _numerics, _train_result, loss_value = run_step(fetches, feed_dict=feed)
    return loss_value
评论列表
文章目录