def test(self, dataset, subset='test', name='Test'):
    """Evaluate the model on one pass over ``subset`` of ``dataset``.

    Each batch is run through ``self.valid_fetches``. The first result of
    each run is written as a summary to a new, timestamped directory under
    ``hparams.SUMMARY_DIR``. The second result is accumulated into a report
    dict via ``_dict_add``.

    Console output: a ``.`` is printed per batch as a progress marker. At the
    end, a single ``"<name>: <report>"`` line is printed, formatted by
    ``_dict_format``.

    Args:
        dataset: object exposing ``epoch(subset, batch_size)`` that yields
            data points. Only ``data_pt[0]`` is used; it is reshaped to
            ``[BATCH_SIZE, MAX_N_SIGNAL, -1, FEATURE_SIZE]``.
        subset: split name forwarded to ``dataset.epoch``.
        name: label prefixed to the final report line.
    """
    log_dir = os.path.join(
        hparams.SUMMARY_DIR,
        datetime.datetime.now().strftime("%m%d_%H%M%S") + ' ' + hparams.SUMMARY_TITLE)
    summary_writer = tf.summary.FileWriter(log_dir, g_sess.graph)
    cli_report = {}
    try:
        batches = dataset.epoch(
            subset, hparams.BATCH_SIZE * hparams.MAX_N_SIGNAL)
        for step, data_pt in enumerate(batches):
            # The second fed value is the constant 1.
            # Per the original note, this disables dropout during test.
            # It is presumably a keep-probability placeholder; confirm
            # against self.train_feed_keys.
            to_feed = dict(zip(self.train_feed_keys, (
                np.reshape(
                    data_pt[0],
                    [hparams.BATCH_SIZE, hparams.MAX_N_SIGNAL, -1, hparams.FEATURE_SIZE]),
                1.)))
            step_summary, step_fetch = g_sess.run(
                self.valid_fetches, to_feed)[:2]
            # Record the batch index as the step so each batch gets its own
            # point instead of every summary being logged without a step.
            summary_writer.add_summary(step_summary, step)
            stdout.write('.')
            stdout.flush()
            _dict_add(cli_report, step_fetch)
    finally:
        # Close the writer even if a batch fails, so pending events are
        # flushed and the event file handle is released.
        summary_writer.close()
    stdout.write(name + ': %s\n' % (_dict_format(cli_report),))
评论列表
文章目录