train_val.py 文件源码

python
阅读 36 收藏 0 点赞 0 评论 0

项目:lsdc 作者: febert 项目源码 文件源码
def __init__(self, conf, images=None, scores=None, goal_pos=None, desig_pos=None):
        """Build the value-model training graph: inputs, loss, optimizer, summary.

        Any of ``images``, ``scores``, ``goal_pos`` or ``desig_pos`` left as
        ``None`` is replaced by a float32 placeholder that must be fed at run
        time. Whichever tensor ends up being used (caller-supplied or
        placeholder) is stored on ``self`` under the same name, so the
        attributes exist regardless of how the model was constructed.

        Args:
            conf: Mapping read for ``'batch_size'`` (converted with ``int``)
                and ``'learning_rate'``; also passed through unchanged to
                ``construct_model``.
            images: Optional image tensor. Default placeholder shape is
                ``(batch_size, 1, 64, 64, 3)``.
            scores: Optional regression-target tensor. Default placeholder
                shape is ``(batch_size, 1)``.
            goal_pos: Optional tensor. Default placeholder shape is
                ``(batch_size, 2)``.
            desig_pos: Optional tensor. Default placeholder shape is
                ``(batch_size, 2)``.

        Attributes set:
            goal_pos, desig_pos, scores, images: the input tensors in use.
            prefix: scalar string placeholder prepended to the loss summary tag.
            inf_scores: output of ``construct_model``.
            loss: ``mean_squared_error(inf_scores, scores)``.
            lr: learning-rate placeholder defaulting to ``conf['learning_rate']``.
            train_op: Adam minimization step on ``loss``.
            summ_op: merged summary op (currently only the loss scalar).
        """
        batchsize = int(conf['batch_size'])

        # Create a feed placeholder only for the inputs the caller did not supply.
        if goal_pos is None:
            goal_pos = tf.placeholder(tf.float32, name='goalpos', shape=(batchsize, 2))
        if desig_pos is None:
            desig_pos = tf.placeholder(tf.float32, name='desig_pos_pl', shape=(batchsize, 2))
        if scores is None:
            scores = tf.placeholder(tf.float32, name='score_pl', shape=(batchsize, 1))
        if images is None:
            images = tf.placeholder(tf.float32, name='images_pl', shape=(batchsize, 1, 64, 64, 3))

        # Always record the inputs on the instance. Previously these attributes
        # were assigned only in the placeholder branches, so passing a tensor
        # in left the matching attribute undefined.
        self.goal_pos = goal_pos
        self.desig_pos = desig_pos
        self.scores = scores
        self.images = images

        # Scalar string fed at run time; it becomes part of the summary tag.
        self.prefix = prefix = tf.placeholder(tf.string, [])

        # Imported locally, as in the original code.
        # NOTE(review): presumably this avoids an import cycle or a hard
        # dependency at module import time -- confirm before hoisting.
        from value_model import construct_model

        summaries = []
        inf_scores = construct_model(conf, images, goal_pos, desig_pos)
        self.inf_scores = inf_scores
        self.loss = loss = mean_squared_error(inf_scores, scores)

        summaries.append(tf.scalar_summary(prefix + '_loss', loss))

        # Uses conf['learning_rate'] unless a different value is fed.
        self.lr = tf.placeholder_with_default(conf['learning_rate'], ())

        self.train_op = tf.train.AdamOptimizer(self.lr).minimize(loss)
        self.summ_op = tf.merge_summary(summaries)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号