def test_train_predict3():
    """
    Test that a model trained on sequences of one length can be used for
    predictions on other sequence lengths.

    Phase 1 trains an "embedding_attention" seq2seq model on "sorted"
    sequences (input length 10, output length 10) with ``data_dir`` set to a
    fresh temporary directory, and checks that the weights output file exists.
    Phase 2 resets the TF graph, builds a new model for input length 20 /
    output length 8 in the same ``data_dir``, loads weights with
    ``weights_input_fn=1`` and checks the prediction length.
    The temporary directory is always removed, even if an assertion fails.
    """
    import shutil
    import tempfile

    tempdir = tempfile.mkdtemp()
    try:
        # Phase 1: train on length-10 -> length-10 sequences and save weights.
        sp = SequencePattern("sorted", in_seq_len=10, out_seq_len=10)
        ts2s = TFLearnSeq2Seq(sp, seq2seq_model="embedding_attention",
                              data_dir=tempdir, name="attention")
        tf.reset_default_graph()
        ts2s.train(num_epochs=1, num_points=1000,
                   weights_output_fn=1, weights_input_fn=0)
        assert os.path.exists(ts2s.weights_output_fn)

        # Phase 2: fresh graph, different sequence lengths.
        # NOTE(review): data_dir must match phase 1 so that weights_input_fn=1
        # presumably resolves to the weights saved above; the original used
        # "DATA", which would load unrelated (or missing) weights.
        tf.reset_default_graph()
        sp = SequencePattern("sorted", in_seq_len=20, out_seq_len=8)
        ts2s = TFLearnSeq2Seq(sp, seq2seq_model="embedding_attention",
                              data_dir=tempdir, name="attention", verbose=1)
        x = np.random.randint(0, 9, 20)
        prediction, _ = ts2s.predict(x, weights_input_fn=1)
        # Compare the length itself: the original `len(prediction==8)` took the
        # length of the comparison result and so never failed.
        assert len(prediction) == 8
    finally:
        # Portable, shell-free cleanup that runs even when an assert fails.
        shutil.rmtree(tempdir, ignore_errors=True)
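# NOTE(review): stray web-page text ("Comment list", "Table of contents") was
# pasted here; it is not part of the code.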