BoxMNIST.py 文件源码

python
阅读 18 收藏 0 点赞 0 评论 0

项目:Athena 作者: bakhyeonjae 项目源码 文件源码
def execute(self):
        """Train an MNIST classifier, report test accuracy, then emit data on the out ports.

        Steps, in order:
          1. Build ``MnistModel`` and MNIST train/test ``DataLoader``s (the
             training split is downloaded into ``./data`` if missing).
          2. Print the size of every model parameter.
          3. Train for 15 epochs with Adam (lr=1e-4) and NLL loss, printing
             loss/accuracy every 1000 steps and plotting the per-step loss and
             accuracy curves with matplotlib.
          4. Evaluate on the test split and print the overall accuracy.
          5. Set ``self.data`` to a ``self.dim[0] x self.dim[1]`` array of
             uniform random numbers and pass it to ``transferData`` on every
             port in ``self.outPorts``.

        NOTE(review): the array sent to the out ports is random, not derived
        from the trained model -- confirm this placeholder is intended.
        """
        model = MnistModel()

        batch_size = 50
        train_loader = torch.utils.data.DataLoader(
            datasets.MNIST('data', train=True, download=True, transform=transforms.ToTensor()),
            batch_size=batch_size, shuffle=True)
        test_loader = torch.utils.data.DataLoader(
            datasets.MNIST('data', train=False, transform=transforms.ToTensor()),
            batch_size=1000)

        for p in model.parameters():
            print(p.size())

        optimizer = optim.Adam(model.parameters(), lr=0.0001)

        model.train()
        train_loss = []
        train_accu = []
        step = 0
        for _ in range(15):
            for data, target in train_loader:
                optimizer.zero_grad()
                output = model(data)
                # nll_loss expects log-probabilities; presumably MnistModel ends
                # in log_softmax -- TODO confirm.
                loss = F.nll_loss(output, target)
                loss.backward()    # compute gradients
                # .item() converts the 0-dim loss tensor to a Python float; the
                # old `loss.data[0]` raises IndexError on PyTorch >= 0.5.
                loss_value = loss.item()
                train_loss.append(loss_value)
                optimizer.step()   # update parameters from the gradients
                # max(1) returns (values, indices); the indices are the predicted classes.
                prediction = output.detach().max(1)[1]
                # Divide by the real batch length (the final batch may be smaller
                # than batch_size) and use float math to avoid integer division.
                accuracy = prediction.eq(target).sum().item() * 100.0 / target.size(0)
                train_accu.append(accuracy)
                if step % 1000 == 0:
                    print('Train Step: {}\tLoss: {:.3f}\tAccuracy: {:.3f}'.format(step, loss_value, accuracy))
                step += 1

        # Loss and accuracy (in percent) share one axes even though their scales
        # differ; the figure is drawn but not shown or saved in this method.
        plt.plot(np.arange(len(train_loss)), train_loss)
        plt.plot(np.arange(len(train_accu)), train_accu)

        model.eval()
        correct = 0
        # no_grad replaces the removed `volatile=True` flag: without it, modern
        # PyTorch records an autograd graph for every test batch.
        with torch.no_grad():
            for data, target in test_loader:
                output = model(data)
                prediction = output.max(1)[1]
                correct += prediction.eq(target).sum().item()

        print('\nTest set: Accuracy: {:.2f}%'.format(100. * correct / len(test_loader.dataset)))

        self.data = np.random.rand(self.dim[0], self.dim[1])

        for port in self.outPorts:
            port.transferData(self.data)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号