tflearn_seq2seq.py source code

python

Project: tflearn_seq2seq  Author: ichuang
def test_train_predict2():
    '''
    Test that the embedding_attention model works, with saving and loading of weights
    '''
    import tempfile
    sp = SequencePattern()
    tempdir = tempfile.mkdtemp()
    ts2s = TFLearnSeq2Seq(sp, seq2seq_model="embedding_attention", data_dir=tempdir, name="attention")
    tf.reset_default_graph()
    ts2s.train(num_epochs=1, num_points=1000, weights_output_fn=1, weights_input_fn=0)
    assert os.path.exists(ts2s.weights_output_fn)

    tf.reset_default_graph()
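    # Reload from the same directory the weights were saved to, so the test does not depend on a pre-existing "DATA" folder
    ts2s = TFLearnSeq2Seq(sp, seq2seq_model="embedding_attention", data_dir=tempdir, name="attention", verbose=1)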
    prediction, y = ts2s.predict(Xin=range(10), weights_input_fn=1)
    assert len(prediction) == 10  # compare the length; the original took len() of the comparison result

    import shutil
    shutil.rmtree(tempdir, ignore_errors=True)  # portable cleanup; avoids shelling out to rm
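
The save-then-reload pattern this test exercises can be sketched with the standard library alone. This is only an illustration under stated assumptions: `save_weights`, `load_weights`, and `roundtrip_in_tempdir` are hypothetical stand-ins and are not part of tflearn_seq2seq. Using `tempfile.TemporaryDirectory` means the directory is removed even if an assertion fails partway through.

```python
import json
import os
import tempfile


def save_weights(weights, path):
    # Hypothetical stand-in for a model's save step: write weights to disk as JSON.
    with open(path, "w") as f:
        json.dump(weights, f)


def load_weights(path):
    # Hypothetical stand-in for a model's load step.
    with open(path) as f:
        return json.load(f)


def roundtrip_in_tempdir(weights):
    # TemporaryDirectory deletes the directory on exit, even if an assertion fails.
    with tempfile.TemporaryDirectory() as tempdir:
        path = os.path.join(tempdir, "weights.json")
        save_weights(weights, path)
        assert os.path.exists(path)  # the save step produced a file
        restored = load_weights(path)  # load from the same directory we saved to
    # The directory is gone once the with-block exits.
    assert not os.path.exists(tempdir)
    return restored


restored = roundtrip_in_tempdir([0.1, 0.2, 0.3])
assert restored == [0.1, 0.2, 0.3]
assert len(restored) == 3  # length check written as len(x) == n, not len(x == n)
```

Saving and loading through the same temporary directory keeps the test independent of files left behind by earlier runs.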