def main():
    """Restore a trained MLPcVAE checkpoint, run it over a dataset, and plot.

    Reads options from the module-level ``args`` object (``logdir``,
    ``file_pattern``, ``target_id``).

    The model architecture is loaded from ``<logdir>/architecture.json``.
    The weights come from the checkpoint returned by ``get_checkpoint``.
    Each input is encoded to a latent ``z`` and then decoded twice:
      * with its own labels ``y``, giving ``xh``;
      * with every label replaced by ``args.target_id``, giving ``x_conv``.

    The loop runs up to ``reader.n_files`` times. Each iteration fetches
    ``x``, ``xh``, ``x_conv`` and the filename, and passes the resulting dict
    to ``plot_spectra``.

    Raises:
        ValueError: if ``args.logdir`` is None, or if ``get_checkpoint``
            returns None for it.
    """
    if args.logdir is None:
        # logdir is a directory (it is joined with 'architecture.json' below),
        # not a file.
        raise ValueError('Please specify the logdir')
    ckpt = get_checkpoint(args.logdir)
    if ckpt is None:
        raise ValueError('No checkpoints in {}'.format(args.logdir))
    with open(os.path.join(args.logdir, 'architecture.json')) as f:
        arch = json.load(f)

    # Input pipeline. num_epochs=1 -> a single pass over the matching files.
    reader = VCC2016TFRManager()
    features = reader.read_whole(args.file_pattern, num_epochs=1)
    x = features['frame']
    y = features['label']
    filename = features['filename']
    # Same shape/dtype as y, with every entry set to target_id.
    y_conv = y * 0 + args.target_id

    # Graph construction order is kept as-is: variable creation inside the
    # model must match what the checkpoint was saved from.
    net = MLPcVAE(arch=arch, is_training=False)
    z = net.encode(x)
    xh = net.decode(z, y)
    x_conv = net.decode(z, y_conv)

    pre_train_saver = tf.train.Saver()

    def load_pretrain(sess):
        # Called by the Supervisor after initialization; overwrites the
        # freshly initialized variables with the checkpointed values.
        pre_train_saver.restore(sess, ckpt)

    # No logdir is given to the Supervisor, so it is used only for session
    # management and initialization, not for saving.
    sv = tf.train.Supervisor(init_fn=load_pretrain)
    gpu_options = tf.GPUOptions(allow_growth=True)
    sess_config = tf.ConfigProto(
        allow_soft_placement=True,
        gpu_options=gpu_options)

    # The fetch structure does not change between iterations; build it once.
    fetch_dict = {'x': x, 'xh': xh, 'x_conv': x_conv, 'f': filename}
    with sv.managed_session(config=sess_config) as sess:
        # NOTE(review): assumes one sess.run yields one file's data
        # (read_whole), hence reader.n_files iterations -- confirm.
        for _ in range(reader.n_files):
            if sv.should_stop():
                break
            results = sess.run(fetch_dict)
            plot_spectra(results)
# (web-page residue, not code) "Comment list"
# (web-page residue, not code) "Table of contents"