beam_search_test.py 文件源码

python
阅读 27 收藏 0 点赞 0 评论 0

项目:conv_seq2seq 作者: tobyyouup 项目源码 文件源码
def test_eos_masking(self):
    probs = tf.constant([[-.2, -.2, -.2, -.2, -.2], [-.3, -.3, -.3, 3, 0],
                         [5, 6, 0, 0, 0]])
    eos_token = 0
    previously_finished = tf.constant([0, 1, 0], dtype=tf.float32)
    masked = beam_search.mask_probs(probs, eos_token, previously_finished)

    with self.test_session() as sess:
      probs = sess.run(probs)
      masked = sess.run(masked)

      np.testing.assert_array_equal(probs[0], masked[0])
      np.testing.assert_array_equal(probs[2], masked[2])
      np.testing.assert_equal(masked[1][0], 0)
      np.testing.assert_approx_equal(masked[1][1], np.finfo('float32').min)
      np.testing.assert_approx_equal(masked[1][2], np.finfo('float32').min)
      np.testing.assert_approx_equal(masked[1][3], np.finfo('float32').min)
      np.testing.assert_approx_equal(masked[1][4], np.finfo('float32').min)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号