train.py source code

python

Project: DeepWorks    Author: daigo0927
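import numpy as np
from tqdm import tqdm

def valid(self, batch_size=128, weights_file=None):
    # Optionally restore trained variables from a checkpoint before evaluating
    if weights_file is not None:
        self.saver.restore(self.sess, weights_file)

    data_size = self.x_test.shape[0]
    # Floor division: the last data_size % batch_size test samples are skipped
    num_batches = data_size // batch_size

    acc_vals = []
    # Shuffle the test indices (not required for evaluation, but harmless)
    permute_idx = np.random.permutation(data_size)
    for b in tqdm(range(num_batches)):
        idx = permute_idx[b * batch_size:(b + 1) * batch_size]
        x_val = self.x_test[idx]
        y_val = self.y_test[idx]

        # self.accuracy is assumed to be the mean accuracy over the fed batch
        acc_val = self.sess.run(self.accuracy,
                                feed_dict={self.images: x_val, self.labels: y_val})
        acc_vals.append(acc_val)

    # Every batch has the same size, so the mean of the per-batch accuracies
    # equals the accuracy over all evaluated samples
    print('validation accuracy : {}'.format(np.mean(acc_vals)))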
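The batching logic of valid() can be sketched without TensorFlow as follows. This is a minimal sketch under stated assumptions: `batched_accuracy` and the `predict` callable are hypothetical stand-ins for the session run on `self.accuracy`, and labels are assumed to be integer class IDs rather than one-hot vectors (the original's label format is not shown). It makes the original's trade-off explicit: floor division drops the last `data_size % batch_size` samples, and a plain mean of per-batch accuracies is only exact when every batch has the same size, so the sketch weights each batch by its size.

```python
import numpy as np


def batched_accuracy(predict, x, y, batch_size=128, drop_remainder=True, seed=0):
    """Accuracy of `predict` on (x, y), evaluated in shuffled mini-batches.

    `predict` is a hypothetical callable mapping a batch of inputs to integer
    class labels; it stands in for the TensorFlow session run in valid().
    """
    n = x.shape[0]
    # drop_remainder=True mirrors valid(): floor division skips n % batch_size samples.
    # drop_remainder=False uses ceiling division so every sample is evaluated.
    num_batches = n // batch_size if drop_remainder else -(-n // batch_size)
    if num_batches == 0:
        raise ValueError("fewer samples than batch_size with drop_remainder=True")

    idx = np.random.default_rng(seed).permutation(n)
    accs, sizes = [], []
    for b in range(num_batches):
        batch = idx[b * batch_size:(b + 1) * batch_size]
        accs.append(np.mean(predict(x[batch]) == y[batch]))
        sizes.append(len(batch))

    # Weighting by batch size gives the exact per-sample accuracy even when the
    # final batch is smaller; with equal-size batches it equals np.mean(accs).
    return float(np.average(accs, weights=sizes))


# Usage: 10 samples whose label is the parity of the input value.
x = np.arange(10).reshape(10, 1)
y = x[:, 0] % 2
perfect = lambda xb: xb[:, 0] % 2
always_zero = lambda xb: np.zeros(len(xb), dtype=int)

print(batched_accuracy(perfect, x, y, batch_size=4))                            # 1.0
print(batched_accuracy(always_zero, x, y, batch_size=4, drop_remainder=False))  # 0.5
```

With 10 samples and `batch_size=4`, the original scheme evaluates only 8 samples; passing `drop_remainder=False` evaluates all 10 and weights the final 2-sample batch accordingly.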