main.py source file

python

Project: vae-npvc  Author: JeremyCCHsu
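import json
import os

import numpy as np
import tensorflow as tf  # TF 1.x API (tf.gfile)

# validate_log_dirs, Tanhize, read, MODEL, TRAINER and args are defined
# elsewhere in the project and are not shown in this excerpt.


def main():
    ''' NOTE: The input is rescaled to [-1, 1] '''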

    dirs = validate_log_dirs(args)
    tf.gfile.MakeDirs(dirs['logdir'])

    with open(args.architecture) as f:
        arch = json.load(f)

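    # Keep a copy of the architecture file next to the logs.
    # basename() avoids a missing-directory error (or writing outside logdir)
    # when args.architecture contains a directory component.
    arch_copy = os.path.join(dirs['logdir'], os.path.basename(args.architecture))
    with open(arch_copy, 'w') as f:
        json.dump(arch, f, indent=4)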

    normalizer = Tanhize(
        xmax=np.fromfile('./etc/xmax.npf'),
        xmin=np.fromfile('./etc/xmin.npf'),
    )

    image, label = read(
        file_pattern=arch['training']['datadir'],
        batch_size=arch['training']['batch_size'],
        capacity=2048,
        min_after_dequeue=1024,
        normalizer=normalizer,
    )

    machine = MODEL(arch)

    loss = machine.loss(image, label)
    trainer = TRAINER(loss, arch, args, dirs)
    trainer.train(nIter=arch['training']['max_iter'], machine=machine)
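
The docstring above says the input is rescaled to [-1, 1], and `Tanhize` is built from per-feature `xmax` and `xmin` arrays, but its implementation is not shown in this excerpt. The following is a minimal NumPy sketch of what such a min-max rescaling could look like. The class name `RangeNormalizer` and its `forward`/`backward` methods are hypothetical and are not taken from the project; the clipping step is likewise an assumption.

```python
import numpy as np


class RangeNormalizer:
    """Hypothetical stand-in for a [-1, 1] normalizer (not the project's Tanhize)."""

    def __init__(self, xmax, xmin):
        # Per-feature maxima and minima, e.g. loaded with np.fromfile.
        self.xmax = np.asarray(xmax, dtype=np.float64)
        self.xmin = np.asarray(xmin, dtype=np.float64)
        self.xscale = self.xmax - self.xmin

    def forward(self, x):
        # Map [xmin, xmax] to [0, 1], clip outliers (assumption), then to [-1, 1].
        x = (np.asarray(x, dtype=np.float64) - self.xmin) / self.xscale
        x = np.clip(x, 0.0, 1.0)
        return x * 2.0 - 1.0

    def backward(self, y):
        # Inverse mapping: [-1, 1] back to the original feature range.
        return (np.asarray(y, dtype=np.float64) * 0.5 + 0.5) * self.xscale + self.xmin


normalizer = RangeNormalizer(xmax=[10.0, 4.0], xmin=[0.0, 2.0])
batch = np.array([[5.0, 2.0], [10.0, 4.0]])
scaled = normalizer.forward(batch)
restored = normalizer.backward(scaled)
```

Bounding the inputs to [-1, 1] pairs naturally with a `tanh` output layer, which is presumably where the name `Tanhize` comes from. Values outside the stored range are clipped in this sketch, so `backward` only recovers inputs that were inside `[xmin, xmax]` to begin with.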