def test_lm(self):
    """Check that the language model's loss drops below 1.0 on toy data.

    Builds an ``LM`` from the test hyper-parameters inside the "model"
    variable scope, runs up to 50 training steps on batches produced by
    ``simple_data_generator`` and asserts that the last observed loss is
    less than 1.0. The training loop stops early if the loss becomes NaN.
    """
    hparams = get_test_hparams()

    with tf.variable_scope("model"):
        model = LM(hparams)

    with self.test_session() as session:
        # Run the initializers for all variables and for local variables.
        tf.initialize_all_variables().run()
        tf.initialize_local_variables().run()

        # Start well above the 1.0 threshold checked at the end.
        loss = 1e5
        for step in range(50):
            inputs, targets, weights = simple_data_generator(
                hparams.batch_size, hparams.num_steps
            )
            fetches = [model.loss, model.train_op]
            feed = {model.x: inputs, model.y: targets, model.w: weights}
            # Only the loss value is kept; the train op's result is unused.
            loss = session.run(fetches, feed)[0]

            exp_loss = np.exp(loss)
            print("%d: %.3f %.3f" % (step, loss, exp_loss))

            if not np.isnan(loss):
                continue
            # Stop training on NaN; the NaN loss is what reaches the
            # assertion below.
            print("NaN detected")
            break

        self.assertLess(loss, 1.0)
# Comment list
# Table of contents