tflearn_seq2seq.py source code

python

Project: tflearn_seq2seq  Author: ichuang
def test_train_predict3():
    '''
    Test that a model trained on sequences of one length can be used for predictions on other sequence lengths
    '''
    import tempfile
    sp = SequencePattern("sorted", in_seq_len=10, out_seq_len=10)
    tempdir = tempfile.mkdtemp()
    ts2s = TFLearnSeq2Seq(sp, seq2seq_model="embedding_attention", data_dir=tempdir, name="attention")
    tf.reset_default_graph()
    ts2s.train(num_epochs=1, num_points=1000, weights_output_fn=1, weights_input_fn=0)
    assert os.path.exists(ts2s.weights_output_fn)

    tf.reset_default_graph()
    sp = SequencePattern("sorted", in_seq_len=20, out_seq_len=8)
    tf.reset_default_graph()
    ts2s = TFLearnSeq2Seq(sp, seq2seq_model="embedding_attention", data_dir="DATA", name="attention", verbose=1)
    x = np.random.randint(0, 9, 20)
    prediction, y = ts2s.predict(x, weights_input_fn=1)
    assert len(prediction) == 8  # compare the length itself, not the length of an elementwise comparison

    import shutil
    shutil.rmtree(tempdir)  # portable cleanup instead of shelling out to rm -rf
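
The test above creates a temporary directory by hand and removes it at the end, so the directory is left behind if an assertion fails first. As a minimal sketch of an alternative, the standard library's `tempfile.TemporaryDirectory` context manager cleans up on exit either way. The helper name `run_with_scratch_dir` and the file name `weights.tfl` are hypothetical stand-ins, not part of tflearn_seq2seq.

```python
import os
import tempfile


def run_with_scratch_dir():
    """Write a placeholder file in a scratch directory; return the directory
    path and whether the file existed while the directory was alive."""
    with tempfile.TemporaryDirectory() as scratch:
        # Hypothetical stand-in for a saved weights file.
        weights_path = os.path.join(scratch, "weights.tfl")
        with open(weights_path, "w") as handle:
            handle.write("placeholder")
        existed_inside = os.path.exists(weights_path)
    # The directory and its contents are removed once the with-block exits,
    # even if an exception had been raised inside it.
    return scratch, existed_inside


scratch_dir, existed = run_with_scratch_dir()
print(existed, os.path.exists(scratch_dir))
```

In a test like the one above, the training and prediction steps would sit inside the `with` block, assuming both are pointed at the same directory.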