mnist_test.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:mist-rnns 作者: rdipietro 项目源码 文件源码
def main():
  """ Test an RNN for sequential (possibly permuted) MNIST recognition.

  Loads the MNIST test split, builds the classification model, restores its
  weights from `<results_dir>/model.ckpt`, and evaluates the classification
  error rate over the whole test set. The error rate is printed and also
  written to `<results_dir>/test_result.txt`.
  """

  args, params_str, layer_kwargs = parse_args()

  outs = mnist.load_split(args.data_dir, val=False, permute=args.permute, normalize=True, seed=0)
  _, _, test_images, test_labels = outs

  # Flatten each image into a sequence of INPUT_SIZE-wide time steps:
  # shape (num_images, length, INPUT_SIZE).
  test_inputs = test_images.reshape([len(test_images), -1, INPUT_SIZE])

  # Align sequence-level labels with the final time step by left-padding the time
  # axis with NaNs (presumably read as "no target here" by the model's valid_*
  # masking -- confirm in models). NaN padding requires a float array; np.float64
  # replaces the `np.float` alias, which was removed in NumPy 1.24.
  length = test_inputs.shape[1]

  def _pad_time(x):
    # Prepend (length - 1) NaN steps so the label sits at the last time step.
    return np.pad(x, [[0, 0], [length - 1, 0], [0, 0]], mode='constant', constant_values=np.nan)

  test_labels = _pad_time(test_labels.reshape([-1, 1, 1]).astype(np.float64))

  test_batches = utils.full_bptt_batch_generator(test_inputs, test_labels, TEST_BATCH_SIZE, num_epochs=1,
                                                 shuffle=False)

  model = models.RNNClassificationModel(args.layer_type, INPUT_SIZE, TARGET_SIZE, args.num_hidden_units,
                                        args.activation_type, **layer_kwargs)

  def _error_rate(valid_predictions, valid_targets):
    # Fraction of rows whose predicted class (argmax over axis 1) differs from the target class.
    incorrect_mask = tf.logical_not(tf.equal(tf.argmax(valid_predictions, 1), tf.argmax(valid_targets, 1)))
    return tf.reduce_mean(tf.to_float(incorrect_mask))
  model.error_rate = _error_rate(model.valid_predictions, model.valid_targets)

  config = tf.ConfigProto()
  config.gpu_options.allow_growth = False

  # Accumulate (rate * rows) per batch instead of averaging per-batch rates, so a
  # smaller final batch does not get the same weight as a full one.
  num_errors = 0.0
  num_examples = 0
  # The context manager closes the session once evaluation is done.
  with tf.Session(config=config) as sess:
    saver = tf.train.Saver()
    saver.restore(sess, os.path.join(args.results_dir, 'model.ckpt'))

    for batch_inputs, batch_labels in test_batches:
      batch_targets = utils.one_hot(np.squeeze(batch_labels, 2), TARGET_SIZE)
      valid_targets, batch_error_rate = sess.run(
        [model.valid_targets, model.error_rate],
        feed_dict={model.inputs: batch_inputs,
                   model.targets: batch_targets}
      )
      # error_rate is a mean over the rows of valid_targets, so weight by that row count.
      batch_size = len(valid_targets)
      num_errors += float(batch_error_rate) * batch_size
      num_examples += batch_size

  # No evaluated rows -> NaN, matching the previous np.mean of an empty list.
  error_rate = num_errors / num_examples if num_examples else float('nan')
  print('%f' % error_rate)
  with open(os.path.join(args.results_dir, 'test_result.txt'), 'w') as f:
    print('%f' % error_rate, file=f)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号