models.py 文件源码

python
阅读 20 收藏 0 点赞 0 评论 0

项目:malmomo 作者: matpalm 项目源码 文件源码
def train(self, batch):
    """Run one training step on `batch` and return the loss.

    A single coin flip (probability 0.5) sets `base_network.FLIP_HORIZONTALLY`
    for this step. The same feed dict, and therefore the same flip, is used
    for the optional VERBOSE_DEBUG pass and for the training pass. As a
    result, the debug output describes the inputs that are actually trained on.

    Args:
      batch: object with `state_1`, `action`, `reward`, `terminal_mask`,
        `state_2` and `weight` attributes, fed straight into the graph.
        These are presumably numpy arrays, given the `.T` used when printing
        (TODO confirm against the replay-memory code).

    Returns:
      The value of `self.loss` fetched in the same `session.run` call that
      executes `self.check_numerics` and `self.train_op`.
    """
    flip_horizontally = np.random.random() < 0.5
    sess = tf.get_default_session()

    # Built once so the debug pass and the training pass cannot drift apart.
    feed_dict = {self.input_state: batch.state_1,
                 self.input_action: batch.action,
                 self.reward: batch.reward,
                 self.terminal_mask: batch.terminal_mask,
                 self.input_state_2: batch.state_2,
                 self.importance_weight: batch.weight,
                 base_network.IS_TRAINING: True,
                 base_network.FLIP_HORIZONTALLY: flip_horizontally}

    if VERBOSE_DEBUG:
      def show(label, value):
        # Single-argument print() behaves identically under Python 2 and 3.
        # "%s %s" reproduces the separating space of Python 2's `print a, b`.
        print("%s %s" % (label, value))

      print("batch.action")
      print(batch.action.T)
      show("batch.reward", batch.reward.T)
      show("batch.terminal_mask", batch.terminal_mask.T)
      show("flip_horizontally", flip_horizontally)
      show("weights", batch.weight.T)
      # NOTE(review): print_gradient_norms is fetched but its value is not
      # printed here. It is presumably run for a side effect (e.g. a print op);
      # confirm where it is defined.
      values = sess.run([self._l_values, self.value_net.value,
                         self.advantage, self.target_value_net.value,
                         self.print_gradient_norms],
                        feed_dict=feed_dict)
      values = [np.squeeze(v) for v in values]
      show("_l_values", values[0].T)
      show("value_net.value        ", values[1].T)
      show("advantage              ", values[2].T)
      show("target_value_net.value ", values[3].T)

    # check_numerics is fetched alongside train_op; its output is discarded.
    _, _, loss = sess.run([self.check_numerics, self.train_op, self.loss],
                          feed_dict=feed_dict)
    return loss
评论列表
文章目录


问题


面经


文章

微信公众号

扫码关注公众号