basic_train.py source code

Project: master-thesis  Author: AndreasMadsen
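# NOTE: the excerpt omits its imports; the ones below are assumed from the
# names used in the function (stf is taken to be the sugartensor package).
from os import path

import sugartensor as stf
import tensorflow as tf
from tensorflow.python.client import timeline as tf_timeline


def basic_train(loss_op, update_op,
                profile=0, save_dir='asset/unamed',
                **kwargs):
    # update_op must be a list of ops: it is concatenated with [loss_op] below.
    # _ShouldProfile is defined elsewhere in the source file (not shown here).
    profile_state = _ShouldProfile(profile)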

    @stf.sg_train_func
    def train_func(sess, arg):
        profile_state.increment()

        # Request a full trace only on the step(s) selected for profiling.
        if profile_state.should_profile():
            options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
            run_metadata = tf.RunMetadata()
        else:
            options = None
            run_metadata = None

        loss = sess.run([loss_op] + update_op,
                        options=options,
                        run_metadata=run_metadata)[0]

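        # Dump the collected step stats in Chrome trace format; the file can
        # be loaded in Chrome's chrome://tracing viewer.
        if profile_state.should_profile():
            tl = tf_timeline.Timeline(run_metadata.step_stats)
            ctf = tl.generate_chrome_trace_format()
            with open(path.join(save_dir, 'timeline.json'), 'w') as fd:
                print(ctf, file=fd)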

        return loss

    # run train function
    train_func(save_dir=save_dir, **kwargs)