def test_train_predict2():
    """Test that the embedding_attention model works, with saving and loading of weights.

    Trains an ``embedding_attention`` TFLearnSeq2Seq model for one epoch,
    writing its weights into a fresh temporary directory. It asserts that the
    weights file exists, then builds a second model instance pointed at the
    same directory and checks that a 10-element prediction can be made from
    the saved weights. The temporary directory is always removed afterwards,
    even if an assertion fails.
    """
    import shutil
    import tempfile

    sp = SequencePattern()
    tempdir = tempfile.mkdtemp()
    try:
        ts2s = TFLearnSeq2Seq(
            sp,
            seq2seq_model="embedding_attention",
            data_dir=tempdir,
            name="attention",
        )
        tf.reset_default_graph()
        ts2s.train(
            num_epochs=1,
            num_points=1000,
            weights_output_fn=1,
            weights_input_fn=0,
        )
        assert os.path.exists(ts2s.weights_output_fn)

        tf.reset_default_graph()
        # Reload from the same data_dir used for training (the original used
        # "DATA"), so the weights loaded below are the ones just saved.
        # NOTE(review): assumes an integer weights_input_fn is resolved
        # relative to data_dir by TFLearnSeq2Seq — confirm.
        ts2s = TFLearnSeq2Seq(
            sp,
            seq2seq_model="embedding_attention",
            data_dir=tempdir,
            name="attention",
            verbose=1,
        )
        prediction, _y = ts2s.predict(Xin=range(10), weights_input_fn=1)
        # The original `len(prediction==10)` only checked non-emptiness;
        # this checks the actual length.
        assert len(prediction) == 10
    finally:
        # Clean up even on assertion failure, without shelling out.
        shutil.rmtree(tempdir, ignore_errors=True)
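# NOTE(review): the two lines that stood here were web-page navigation residue
# from scraping ("评论列表" = "Comment list", "文章目录" = "Table of contents"), not code.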