models_test.py source code

python

Project: seq2seq    Author: google
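def test_train(self):
    # Run the full model pipeline in TRAIN mode. Relies on module-level
    # `import numpy as np` and `import tensorflow as tf`; `_test_pipeline`
    # is a helper defined elsewhere in the test class.
    model, fetches_ = self._test_pipeline(tf.contrib.learn.ModeKeys.TRAIN)
    predictions_, loss_, _ = fetches_

    # Expected target length: sequence_length + 10, plus 2 more tokens
    # (presumably the special sequence start/end markers).
    target_len = self.sequence_length + 10 + 2
    # Decoding is capped at the model's configured maximum target length.
    max_decode_length = model.params["target.max_seq_len"]
    expected_decode_len = np.minimum(target_len, max_decode_length)

    # Per-step outputs cover one step fewer than the full target sequence.
    np.testing.assert_array_equal(predictions_["logits"].shape, [
        self.batch_size, expected_decode_len - 1,
        model.target_vocab_info.total_size
    ])
    np.testing.assert_array_equal(predictions_["losses"].shape,
                                  [self.batch_size, expected_decode_len - 1])
    np.testing.assert_array_equal(predictions_["predicted_ids"].shape,
                                  [self.batch_size, expected_decode_len - 1])
    # The scalar training loss must be a real number, not NaN.
    self.assertFalse(np.isnan(loss_))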
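The shape checks above boil down to a small piece of arithmetic: cap the target length at `target.max_seq_len`, subtract one step, and compare each output's shape against `[batch, steps]` (plus the vocabulary dimension for logits). The sketch below reproduces only that arithmetic with NumPy so it can run without TensorFlow or the seq2seq package. The helper name `expected_train_shapes`, the zero-filled stand-in arrays, and the concrete values (`max_seq_len=50`, `vocab_size=18`) are hypothetical and are not read from the seq2seq configuration.

```python
import numpy as np


def expected_train_shapes(batch_size, sequence_length, max_seq_len, vocab_size):
    """Hypothetical helper mirroring the shape arithmetic in test_train.

    The +10 and +2 offsets are copied from the test; the -1 reflects the
    one-step difference the test expects between target and output lengths.
    """
    target_len = sequence_length + 10 + 2
    expected_decode_len = int(np.minimum(target_len, max_seq_len))
    steps = expected_decode_len - 1
    return {
        "logits": [batch_size, steps, vocab_size],
        "losses": [batch_size, steps],
        "predicted_ids": [batch_size, steps],
    }


# Hypothetical values, chosen only for illustration.
shapes = expected_train_shapes(batch_size=2, sequence_length=10,
                               max_seq_len=50, vocab_size=18)

# Zero-filled stand-ins with the shapes a correct model would produce.
fake_predictions = {name: np.zeros(shape) for name, shape in shapes.items()}
for name, shape in shapes.items():
    np.testing.assert_array_equal(fake_predictions[name].shape, shape)

print(shapes["logits"])  # [2, 21, 18]
```

With these values the target length is 22, which is below the cap, so each output has 21 steps. Lowering `max_seq_len` to 15 would cap decoding at 15 and leave 14 steps, which is the case `np.minimum` in the test is there to handle.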