main.py source code

python

Project: DaNet-Tensorflow  Author: khaotik
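import datetime
import os
from sys import stdout

import numpy as np
import tensorflow as tf

# hparams, g_sess, _dict_add and _dict_format are module-level names
# defined elsewhere in the DaNet-Tensorflow project.

def test(self, dataset, subset='test', name='Test'):
    # Write summaries to a fresh, timestamped directory under SUMMARY_DIR
    test_writer = tf.summary.FileWriter(
        os.path.join(
            hparams.SUMMARY_DIR,
            datetime.datetime.now().strftime('%m%d_%H%M%S') + ' ' + hparams.SUMMARY_TITLE),
        g_sess.graph)
    cli_report = {}
    for step, data_pt in enumerate(dataset.epoch(
            subset, hparams.BATCH_SIZE * hparams.MAX_N_SIGNAL)):
        # Feed the reshaped input batch; the second value (1.) disables dropout during test
        to_feed = dict(zip(self.train_feed_keys, (
            np.reshape(data_pt[0], [hparams.BATCH_SIZE, hparams.MAX_N_SIGNAL, -1, hparams.FEATURE_SIZE]),
            1.)))
        step_summary, step_fetch = g_sess.run(self.valid_fetches, to_feed)[:2]
        # Pass the step index so TensorBoard plots each batch at its own x position
        test_writer.add_summary(step_summary, step)
        stdout.write('.')  # simple progress indicator
        stdout.flush()
        # Accumulate this batch's metrics into the running report
        _dict_add(cli_report, step_fetch)
    test_writer.close()
    stdout.write(name + ': %s\n' % _dict_format(cli_report))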
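The helpers `_dict_add` and `_dict_format` are not shown in this file. The sketch below is only a guess at the accumulate-then-report pattern they appear to implement. `dict_add`, `dict_format`, the `_steps` counter and the choice to report per-step means are all hypothetical and may not match the project. The metric names and values are also made up.

```python
# Hypothetical stand-ins for the project's _dict_add / _dict_format helpers;
# the real implementations in DaNet-Tensorflow may differ.

def dict_add(acc, step_metrics):
    """Add each metric from one step into the running totals in `acc`."""
    for key, value in step_metrics.items():
        acc[key] = acc.get(key, 0.0) + value
    # Count accumulated steps so a per-step mean can be reported later.
    acc['_steps'] = acc.get('_steps', 0) + 1


def dict_format(acc):
    """Format the per-step mean of every accumulated metric, sorted by key."""
    n = acc.get('_steps', 0) or 1
    return ', '.join(
        '%s: %.4f' % (key, acc[key] / n)
        for key in sorted(acc) if key != '_steps')


# Usage with made-up metric values for two test batches.
report = {}
for metrics in [{'loss': 0.5, 'snr': 10.0}, {'loss': 0.3, 'snr': 12.0}]:
    dict_add(report, metrics)
print('Test: %s' % dict_format(report))  # prints "Test: loss: 0.4000, snr: 11.0000"
```

Reporting means rather than raw sums keeps the printed numbers independent of how many batches the subset contains.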
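In `test`, `dataset.epoch` is called with `BATCH_SIZE * MAX_N_SIGNAL`, and `np.reshape` then regroups the flat batch into a 4-D tensor. The minimal NumPy sketch below illustrates that regrouping. It assumes `data_pt[0]` holds `BATCH_SIZE * MAX_N_SIGNAL` arrays of shape `(frames, FEATURE_SIZE)`. The sizes, including the hypothetical `N_FRAMES`, are illustrative and are not the project's real hyperparameters.

```python
import numpy as np

# Illustrative sizes only; the real values come from the project's hparams.
BATCH_SIZE = 2
MAX_N_SIGNAL = 3
FEATURE_SIZE = 4
N_FRAMES = 5  # hypothetical number of time frames per signal

# Assumed layout: a flat stack of BATCH_SIZE * MAX_N_SIGNAL (frames, features) arrays.
flat = np.arange(BATCH_SIZE * MAX_N_SIGNAL * N_FRAMES * FEATURE_SIZE, dtype=np.float32)
flat = flat.reshape(BATCH_SIZE * MAX_N_SIGNAL, N_FRAMES, FEATURE_SIZE)

# -1 lets NumPy infer the frame axis from the total number of elements.
batch = np.reshape(flat, [BATCH_SIZE, MAX_N_SIGNAL, -1, FEATURE_SIZE])
print(batch.shape)  # (2, 3, 5, 4)

# Row-major order groups consecutive signals into the same batch example:
# flat item i lands at batch[i // MAX_N_SIGNAL, i % MAX_N_SIGNAL].
assert np.array_equal(batch[1, 2], flat[1 * MAX_N_SIGNAL + 2])
```

Because the frame axis is inferred, the same reshape works for any utterance length, provided the total element count divides evenly.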