def valid(self, batch_size=128, weights_file=None):
    """Evaluate ``self.accuracy`` over the whole test set and print the result.

    Every row of ``self.x_test`` / ``self.y_test`` is evaluated exactly once.
    Rows are fed in consecutive chunks of at most ``batch_size`` rows, and a
    trailing partial chunk is included. The per-batch accuracies are then
    combined as a mean weighted by the number of rows in each batch.

    Args:
        batch_size: Maximum number of rows fed per ``self.sess.run`` call.
        weights_file: Optional checkpoint path. When given, it is restored
            into ``self.sess`` with ``self.saver.restore`` before evaluating.

    Returns:
        The overall validation accuracy as a Python float, or ``nan`` if the
        test set is empty.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError('batch_size must be positive, got {}'.format(batch_size))
    if weights_file is not None:
        self.saver.restore(self.sess, weights_file)

    data_size = self.x_test.shape[0]
    # Ceiling division, so a trailing partial batch is not silently dropped
    # (the previous floor division skipped up to batch_size - 1 samples).
    num_batches = (data_size + batch_size - 1) // batch_size

    acc_vals = []
    batch_lens = []
    # Samples are fed in stored order. Every sample is visited exactly once,
    # so shuffling cannot change the result, assuming the accuracy op does not
    # depend on batch composition (NOTE(review): confirm, e.g. no train-mode
    # batch norm).
    for b in tqdm(range(num_batches)):
        start = b * batch_size
        stop = min(start + batch_size, data_size)
        x_val = self.x_test[start:stop]
        y_val = self.y_test[start:stop]
        acc_val = self.sess.run(self.accuracy,
                                feed_dict={self.images: x_val, self.labels: y_val})
        acc_vals.append(acc_val)
        batch_lens.append(stop - start)

    if acc_vals:
        # Weight each batch by its length. This assumes self.accuracy is the
        # mean accuracy over the fed batch (TODO confirm against the graph).
        accuracy = float(np.average(acc_vals, weights=batch_lens))
    else:
        # Empty test set: report nan explicitly instead of np.mean([]) warning.
        accuracy = float('nan')
    print('validation accuracy : {}'.format(accuracy))
    return accuracy
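# NOTE(review): two lines of scraped web-page navigation text stood here
# ("Comment list", "Table of contents"); they are not code.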