def test_eos_masking(self):
    """Check `beam_search.mask_probs` on a batch with one finished row.

    Row 1 is flagged in `previously_finished`. The test asserts:
    - in that row, the EOS column (index 0) becomes 0;
    - the remaining columns of that row become (approximately) the lowest
      float32 value;
    - rows 0 and 2, which are not flagged, come back unchanged.
    """
    probs = tf.constant([
        [-.2, -.2, -.2, -.2, -.2],
        [-.3, -.3, -.3, 3, 0],
        [5, 6, 0, 0, 0],
    ])
    eos_token = 0
    # Only the middle row is marked as already finished.
    previously_finished = tf.constant([0, 1, 0], dtype=tf.float32)
    masked = beam_search.mask_probs(probs, eos_token, previously_finished)

    with self.test_session() as sess:
        # Both fetches come from constant inputs, so one run yields the same
        # values as evaluating them separately.
        probs_val, masked_val = sess.run([probs, masked])

    # Rows not flagged as finished must pass through untouched.
    for row in (0, 2):
        np.testing.assert_array_equal(probs_val[row], masked_val[row])

    # Finished row: EOS column is 0, every other column is float32's minimum.
    finished_row = masked_val[1]
    lowest_float32 = np.finfo('float32').min
    np.testing.assert_equal(finished_row[0], 0)
    for col in (1, 2, 3, 4):
        np.testing.assert_approx_equal(finished_row[col], lowest_float32)
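# NOTE: stray web-page labels captured with this snippet ("Comment list",
# "Article table of contents"); they are not part of the test code.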