test_graph_cnn.py source code

python

Project: chainer-graph-cnn  Author: pfnet-research
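import tempfile

import chainer
import numpy as np
from chainer import optimizers
from chainer.training import ParallelUpdater, extensions

# graph_cnn (the module defining GraphCNN) and EasyDataset are local to the
# chainer-graph-cnn project; their import paths are not shown in this snippet.


# Method of the test case class (the class definition is not shown).
def check_train(self, gpu):
    # Trainer logs go to a fresh temporary directory.
    outdir = tempfile.mkdtemp()
    print("outdir: {}".format(outdir))

    n_classes = 2
    batch_size = 32

    # ParallelUpdater requires a 'main' entry; here a single device is used.
    devices = {'main': gpu}

    # Symmetric adjacency matrix of a 4-node graph with edges 0-1, 0-2, 1-3.
    A = np.array([
        [0, 1, 1, 0],
        [1, 0, 0, 1],
        [1, 0, 0, 0],
        [0, 1, 0, 0],
    ]).astype(np.float32)
    model = graph_cnn.GraphCNN(A, n_out=n_classes)

    optimizer = optimizers.Adam(alpha=1e-4)
    optimizer.setup(model)
    train_dataset = EasyDataset(train=True, n_classes=n_classes)
    # MultiprocessIterator prepares minibatches in worker processes.
    train_iter = chainer.iterators.MultiprocessIterator(
        train_dataset, batch_size)
    updater = ParallelUpdater(train_iter, optimizer, devices=devices)
    # Train for 10 epochs, reporting loss and accuracy once per epoch.
    trainer = chainer.training.Trainer(updater, (10, 'epoch'), out=outdir)
    trainer.extend(extensions.LogReport(trigger=(1, 'epoch')))
    trainer.extend(extensions.PrintReport(
        ['epoch', 'iteration', 'main/loss', 'main/accuracy']))
    trainer.extend(extensions.ProgressBar())
    trainer.run()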
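
The test passes `GraphCNN` a 4x4 adjacency matrix `A` for a small graph with edges 0-1, 0-2 and 1-3. Spectral graph CNNs typically do not convolve with `A` directly. Instead, they derive the normalized graph Laplacian L = I - D^{-1/2} A D^{-1/2} from it, where D is the diagonal degree matrix. The sketch below computes L for this `A` using only NumPy. It assumes this particular normalization, and the helper name `normalized_laplacian` is hypothetical, not part of chainer-graph-cnn. How the project itself preprocesses `A` is not shown in this snippet.

```python
import numpy as np


def normalized_laplacian(A):
    """Return L = I - D^{-1/2} A D^{-1/2} for a symmetric adjacency matrix A.

    Hypothetical helper for illustration; assumes every node has degree > 0.
    """
    A = np.asarray(A, dtype=np.float64)
    degrees = A.sum(axis=1)              # D = diag(degrees)
    d_inv_sqrt = 1.0 / np.sqrt(degrees)  # diagonal of D^{-1/2}
    # Scaling rows and columns by d_inv_sqrt is the same as D^{-1/2} A D^{-1/2}.
    return np.eye(A.shape[0]) - d_inv_sqrt[:, None] * A * d_inv_sqrt[None, :]


# Same 4-node graph as in check_train.
A = np.array([
    [0, 1, 1, 0],
    [1, 0, 0, 1],
    [1, 0, 0, 0],
    [0, 1, 0, 0],
], dtype=np.float32)

L = normalized_laplacian(A)
# Eigenvalues of a normalized Laplacian lie in [0, 2]; a connected graph
# has exactly one zero eigenvalue.
eigvals = np.linalg.eigvalsh(L)
print(eigvals)
```

This graph is the path 2-0-1-3, so the eigenvalues of L are 0, 0.5, 1.5 and 2, up to floating-point error. The largest value is exactly 2 because a path is bipartite.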