gan.py 文件源码

python
阅读 36 收藏 0 点赞 0 评论 0

项目:vae-npvc 作者: JeremyCCHsu 项目源码 文件源码
def train(self, nIter, machine=None, summary_op=None):
        """Run the alternating D/G update loop and save periodic images.

        Each step runs ``self.opt['d']`` ``nIterD`` times and then
        ``self.opt['g']`` once. Every 1000 steps the validation fetch is
        evaluated and its result is written to
        ``<logdir>/img-anime-XXXk.png``.

        Args:
            nIter: Not read by this method; the number of steps is taken
                from ``self.arch['training']['max_iter']``. Kept so that
                existing callers keep working.
            machine: Forwarded unchanged to ``self._validate``.
            summary_op: Not read by this method; kept for interface
                compatibility.
        """
        # NOTE(review): built before the Supervisor is created, presumably
        # because `_validate` adds ops to the graph -- confirm before
        # reordering.
        Xh = self._validate(machine=machine, n=10)

        sv = tf.train.Supervisor(
            logdir=self.dirs['logdir'],
            global_step=self.opt['global_step'])

        # Fall back to another device when an op cannot be placed as
        # requested, and let GPU memory grow on demand.
        sess_config = tf.ConfigProto(
            allow_soft_placement=True,
            gpu_options=tf.GPUOptions(allow_growth=True))

        with sv.managed_session(config=sess_config) as sess:
            # Call `self._refresh_status(sess)` on a 60-second timer.
            sv.loop(60, self._refresh_status, (sess,))

            for step in range(self.arch['training']['max_iter']):
                if sv.should_stop():
                    break

                # `nIterD` runs of the 'd' op for every run of the 'g' op
                # (presumably discriminator / generator updates).
                for _ in range(self.arch['training']['nIterD']):
                    sess.run(self.opt['d'])
                sess.run(self.opt['g'])

                if step % 1000 != 0:
                    continue

                # Dump the current validation output. The fetched value is
                # written verbatim in binary mode, so it is assumed to be
                # already-encoded PNG bytes -- TODO confirm in `_validate`.
                xh = sess.run(Xh)
                img_path = os.path.join(
                    self.dirs['logdir'],
                    'img-anime-{:03d}k.png'.format(step // 1000),
                )
                with tf.gfile.GFile(img_path, mode='wb') as fp:
                    fp.write(xh)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号