import numpy as np
import tensorflow as tf


def _test_metric_spec(self, metric_spec, hyps, refs, expected_scores):
  """Tests a MetricSpec: feeds hypothesis/reference token pairs through its
  update op and checks the streamed metric values against expected_scores.

  Intended to be called with a tf.test.TestCase instance as `self` (or to be
  defined as a method on such a class), since it uses self.test_session().
  """
  predictions = {"predicted_tokens": tf.placeholder(dtype=tf.string)}
  labels = {"target_tokens": tf.placeholder(dtype=tf.string)}
  value, update_op = metric_spec.create_metric_ops(None, labels, predictions)

  with self.test_session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())
    scores = []
    for hyp, ref in zip(hyps, refs):
      # The metric operates on token lists, so split on whitespace and feed
      # a batch of size one.
      hyp = hyp.split(" ")
      ref = ref.split(" ")
      sess.run(update_op, {
          predictions["predicted_tokens"]: [hyp],
          labels["target_tokens"]: [ref]
      })
      scores.append(sess.run(value))

    for score, expected in zip(scores, expected_scores):
      np.testing.assert_almost_equal(score, expected, decimal=2)
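

# A minimal usage sketch, not part of the original helper: it assumes a
# BLEU-style MetricSpec such as seq2seq.metrics.metric_specs.BleuMetricSpec
# whose constructor accepts a params dict (both the import path and the
# empty-params constructor are assumptions). Feeding an identical
# hypothesis/reference pair should stream to a BLEU score of 100.
from seq2seq.metrics.metric_specs import BleuMetricSpec


class BleuMetricSpecTest(tf.test.TestCase):

  def test_identical_sentences(self):
    metric_spec = BleuMetricSpec({})
    _test_metric_spec(
        self, metric_spec,
        hyps=["A B C D E F"],
        refs=["A B C D E F"],
        expected_scores=[100.0])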