def main():
    """Test an RNN trained for TIMIT phoneme recognition.

    Restores ``model.ckpt`` from ``args.results_dir`` and runs the model over
    every batch produced by the (single-epoch, unshuffled) test batch
    generator. It then reports the error rate: the fraction of rows in
    ``model.valid_predictions`` whose argmax differs from the argmax of the
    matching row in ``model.valid_targets``. Those rows are presumably the
    non-padded frames; confirm against the model. The rate is printed to
    stdout and written to ``test_result.txt`` in ``args.results_dir``.

    Raises:
        ValueError: if the generator yields no rows to score.
    """
    args, _, layer_kwargs = parse_args()
    _, _, test_inputs, test_labels = timitphonemerec.load_split(
        args.data_dir, val=False, mfcc=True, normalize=True)

    # Input seqs have shape [length, INPUT_SIZE]. Label seqs are int8 arrays with shape [length],
    # but need to have shape [length, 1] for the batch generator.
    test_labels = [seq[:, np.newaxis] for seq in test_labels]
    test_batches = utils.full_bptt_batch_generator(
        test_inputs, test_labels, TEST_BATCH_SIZE, num_epochs=1, shuffle=False)

    model = models.RNNClassificationModel(
        args.layer_type, INPUT_SIZE, TARGET_SIZE, args.num_hidden_units,
        args.activation_type, **layer_kwargs)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = False
    saver = tf.train.Saver()

    total_incorrect = 0
    total_valid = 0
    # The context manager closes the session (and frees its resources) when done.
    with tf.Session(config=config) as sess:
        saver.restore(sess, os.path.join(args.results_dir, 'model.ckpt'))
        # Score every batch, not just the first, so the result covers the whole
        # test set even when it spans more than TEST_BATCH_SIZE sequences.
        for batch_inputs, batch_labels in test_batches:
            batch_targets = utils.one_hot(np.squeeze(batch_labels, 2), TARGET_SIZE)
            valid_predictions, valid_targets = sess.run(
                [model.valid_predictions, model.valid_targets],
                feed_dict={model.inputs: batch_inputs,
                           model.targets: batch_targets})
            incorrect = np.argmax(valid_predictions, 1) != np.argmax(valid_targets, 1)
            total_incorrect += int(np.sum(incorrect))
            total_valid += incorrect.size

    if total_valid == 0:
        raise ValueError('No valid test frames were produced by the batch generator.')
    # Pool counts across batches rather than averaging per-batch rates, so every
    # row is weighted equally. float() keeps true division under Python 2 as well.
    error_rate = float(total_incorrect) / total_valid

    print('%f' % error_rate)
    with open(os.path.join(args.results_dir, 'test_result.txt'), 'w') as f:
        print('%f' % error_rate, file=f)