validate.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:vae-npvc 作者: JeremyCCHsu 项目源码 文件源码
def main():
    """Restore a trained MLPcVAE, then reconstruct, convert and plot spectra.

    Configuration comes from the module-level ``args`` namespace
    (``logdir``, ``file_pattern``, ``target_id``). The network architecture
    is read from ``<logdir>/architecture.json``. The weights are restored
    from the checkpoint path returned by ``get_checkpoint(logdir)``.

    Each input frame tensor is encoded once and decoded twice:
    - with its own label, giving the reconstruction (fetched as ``'xh'``);
    - with every label replaced by ``args.target_id``, giving the
      conversion (fetched as ``'x_conv'``).

    The graph is then run ``reader.n_files`` times, or until the supervisor
    asks to stop. Each result dict is handed to ``plot_spectra``.

    Raises:
        ValueError: if ``args.logdir`` is not set, or if ``get_checkpoint``
            finds no checkpoint in it.
    """
    log_dir = args.logdir
    if log_dir is None:
        raise ValueError('Please specify the logdir file')

    checkpoint_path = get_checkpoint(log_dir)
    if checkpoint_path is None:
        raise ValueError('No checkpoints in {}'.format(log_dir))

    arch_path = os.path.join(log_dir, 'architecture.json')
    with open(arch_path) as arch_file:
        architecture = json.load(arch_file)

    # Input pipeline: a single pass (num_epochs=1) over the matched files.
    reader = VCC2016TFRManager()
    batch = reader.read_whole(args.file_pattern, num_epochs=1)
    frames = batch['frame']
    labels = batch['label']
    names = batch['filename']
    # This tensor is broadcast to the shape of `labels`, with every entry set
    # to args.target_id. Presumably this is the conversion target (speaker)
    # id; TODO confirm.
    target_labels = labels * 0 + args.target_id

    model = MLPcVAE(arch=architecture, is_training=False)
    latent = model.encode(frames)
    reconstructed = model.decode(latent, labels)
    converted = model.decode(latent, target_labels)

    # The Saver is built after the model so its variable list covers the
    # network just constructed.
    restorer = tf.train.Saver()

    def _restore_weights(session):
        # Supervisor passes the freshly created session to init_fn.
        restorer.restore(session, checkpoint_path)

    supervisor = tf.train.Supervisor(init_fn=_restore_weights)
    session_config = tf.ConfigProto(
        allow_soft_placement=True,
        gpu_options=tf.GPUOptions(allow_growth=True))

    # The same tensors are fetched on every step.
    fetches = {
        'x': frames,
        'xh': reconstructed,
        'x_conv': converted,
        'f': names,
    }
    with supervisor.managed_session(config=session_config) as session:
        for _step in range(reader.n_files):
            if supervisor.should_stop():
                break
            plot_spectra(session.run(fetches))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号