cnn_visualization.py source file

python

Project: NumpyDL  Author: oujago
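import os
from random import randint

import numpy as np


def train():
    # Relies on module-level names defined elsewhere in the file: batch,
    # batchsize, validationData, validationLabel, and the TensorFlow objects
    # img, lbl, keepProb, accuracy, optimizer, saver, sess.
    for i in range(20000):
        # Pick a random contiguous mini-batch (the batch is assumed to hold 10000 images).
        randomint = randint(0, 10000 - batchsize - 1)
        trainingData = batch["data"][randomint:batchsize + randomint]
        rawlabel = batch["labels"][randomint:batchsize + randomint]
        # One-hot encode the labels over 10 classes.
        trainingLabel = np.zeros((batchsize, 10))
        trainingLabel[np.arange(batchsize), rawlabel] = 1
        # Scale pixels to [0, 1] and reshape the flat rows to channels-last.
        trainingData = trainingData / 255.0
        trainingData = np.reshape(trainingData, [-1, 3, 32, 32])
        # Note: swapaxes(1, 3) yields (N, W, H, C), so each image is transposed.
        trainingData = np.swapaxes(trainingData, 1, 3)

        if i % 10 == 0:
            # Accuracy is measured on the validation set with keepProb = 1.0.
            val_accuracy = accuracy.eval(feed_dict={
                img: validationData, lbl: validationLabel, keepProb: 1.0})
            print("step %d, validation accuracy %g" % (i, val_accuracy))

            if i % 50 == 0:
                saver.save(sess, os.path.join(os.getcwd(), "training", "train"), global_step=i)

        optimizer.run(feed_dict={img: trainingData, lbl: trainingLabel, keepProb: 0.5})
        print(i)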
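
The batch preparation inside train() can be tried on its own, without TensorFlow. The following is a minimal sketch, not part of the original file: the helper name prepare_batch and the synthetic fake_data / fake_labels arrays are hypothetical, and the input is assumed to be flat rows of 3072 values in channel-first order (3 x 32 x 32), which is what the reshape in train() implies.

```python
import numpy as np


def prepare_batch(data, labels, start, batch_size, num_classes=10):
    """Slice a contiguous mini-batch and preprocess it the way train() does.

    prepare_batch is a hypothetical helper used only for illustration.
    """
    # Scale pixel values to [0, 1].
    images = data[start:start + batch_size] / 255.0
    # Flat rows -> (N, C, H, W), assuming channel-first storage.
    images = np.reshape(images, [-1, 3, 32, 32])
    # Swap the channel axis with the last spatial axis -> (N, W, H, C).
    images = np.swapaxes(images, 1, 3)

    # One-hot encode the integer class labels.
    raw = labels[start:start + batch_size]
    targets = np.zeros((batch_size, num_classes))
    targets[np.arange(batch_size), raw] = 1
    return images, targets


# Synthetic stand-ins for batch["data"] and batch["labels"] (assumptions).
rng = np.random.default_rng(0)
fake_data = rng.integers(0, 256, size=(100, 3072))
fake_labels = rng.integers(0, 10, size=100).tolist()

images, targets = prepare_batch(fake_data, fake_labels, start=5, batch_size=8)
print(images.shape)   # (8, 32, 32, 3)
print(targets.shape)  # (8, 10)
```

Because swapaxes(1, 3) exchanges only the channel and width axes, images[n, w, h, c] equals the original pixel at [n, c, h, w]. If an untransposed (N, H, W, C) layout is wanted, np.transpose(x, (0, 2, 3, 1)) would be the usual alternative, provided the validation data is prepared the same way.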