def test_main3():
    '''
    Integration test - training then prediction: attention model.

    Runs the command line twice with the "embedding_attention" model:

    1. "train 5000" for 2 epochs ("-e 2"), writing weights with "-o".
       Afterwards it checks that the weights file exists.
    2. "predict" on the 10 tokens 1..9,0, loading those weights with "-i".
       Afterwards it checks that the first prediction result has length 10.

    The weights file is kept in a private temporary directory. That
    directory is deleted afterwards, so the test leaves nothing in the
    current working directory.
    '''
    import shutil
    import tempfile

    tempdir = tempfile.mkdtemp()
    try:
        wfn = os.path.join(tempdir, "tmp_weights.tfl")
        verbose = ["-v"] * 4

        # Build argument lists directly rather than splitting a string, so
        # the temp path cannot be broken apart by whitespace.
        train_args = (["-e", "2", "-o", wfn] + verbose
                      + ["-m", "embedding_attention", "train", "5000"])
        # Start each run from a clean default graph.
        tf.reset_default_graph()
        CommandLine(arglist=train_args)
        assert os.path.exists(wfn)

        predict_args = (["-i", wfn] + verbose
                        + ["-m", "embedding_attention", "predict"]
                        + list("1234567890"))
        tf.reset_default_graph()
        ts2s = CommandLine(arglist=predict_args)
        assert len(ts2s.prediction_results[0][0]) == 10
    finally:
        # Best-effort cleanup of the weights file and any companion files
        # written alongside it.
        shutil.rmtree(tempdir, ignore_errors=True)
#-----------------------------------------------------------------------------
评论列表
文章目录