language_model_test.py 文件源码

python
阅读 31 收藏 0 点赞 0 评论 0

项目:lm 作者: rafaljozefowicz 项目源码 文件源码
def test_lm(self):
        hps = get_test_hparams()

        with tf.variable_scope("model"):
            model = LM(hps)

        with self.test_session() as sess:
            tf.initialize_all_variables().run()
            tf.initialize_local_variables().run()

            loss = 1e5
            for i in range(50):
                x, y, w = simple_data_generator(hps.batch_size, hps.num_steps)
                loss, _ = sess.run([model.loss, model.train_op], {model.x: x, model.y: y, model.w: w})
                print("%d: %.3f %.3f" % (i, loss, np.exp(loss)))
                if np.isnan(loss):
                    print("NaN detected")
                    break

            self.assertLess(loss, 1.0)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号