tflearn_seq2seq.py 文件源码

python
阅读 31 收藏 0 点赞 0 评论 0

项目:tflearn_seq2seq 作者: ichuang 项目源码 文件源码
def test_main3():
    '''
    Integration test - training then prediction: attention model
    '''
    import tempfile
    wfn = "tmp_weights.tfl"
    if os.path.exists(wfn):
        os.unlink(wfn)
    arglist = "-e 2 -o tmp_weights.tfl -v -v -v -v -m embedding_attention train 5000"
    arglist = arglist.split(' ')
    tf.reset_default_graph()
    ts2s = CommandLine(arglist=arglist)
    assert os.path.exists(wfn)

    arglist = "-i tmp_weights.tfl -v -v -v -v -m embedding_attention predict 1 2 3 4 5 6 7 8 9 0" 
    arglist = arglist.split(' ')
    tf.reset_default_graph()
    ts2s = CommandLine(arglist=arglist)
    assert len(ts2s.prediction_results[0][0])==10

#-----------------------------------------------------------------------------
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号