plan_test.py 文件源码

python
阅读 28 收藏 0 点赞 0 评论 0

项目:fold 作者: tensorflow 项目源码 文件源码
def create_plan(self, loom_input_tensor):
    p = plan.TrainPlan()
    foo = tf.get_variable('foo', [], tf.float32, tf.constant_initializer(12))
    p.compiler = block_compiler.Compiler.create(
        blocks.Scalar() >> blocks.Function(lambda x: x * foo),
        loom_input_tensor=loom_input_tensor)
    p.losses['foo'] = p.compiler.output_tensors[0]
    p.finalize_stats()
    p.train_op = tf.train.GradientDescentOptimizer(1.0).minimize(
        p.loss_total, global_step=p.global_step)
    p.logdir = self.get_temp_dir()
    p.dev_examples = [2]
    p.is_chief_trainer = True
    p.batch_size = 2
    p.epochs = 3
    p.print_file = six.StringIO()
    return p
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号