def encoder_pipeline(
        sess, data_stream, token2id, embedding_size,
        encoder_size, bidirectional, decoder_size, attention,
        checkpoint_path,
        batch_size=32, use_norm=False, lstm_connection=1):
    """Restore a trained encoder(-decoder) model and stream embeddings.

    Builds the model from the given hyper-parameters, restores its weights
    from ``checkpoint_path`` into ``sess``, and then lazily yields the
    embedding matrices produced by ``rnn_encoder_encode_stream`` over
    ``data_stream``.

    Yields:
        Embedding matrices, one per batch, as produced by
        ``rnn_encoder_encode_stream``.
    """
    # Vocabulary size is offset by 3 reserved symbols (presumably
    # PAD/GO/EOS — confirm against create_model).
    special_symbol_count = 3

    encoder_config = {
        "cell": rnn.LSTMCell(encoder_size),
        "bidirectional": bidirectional,
    }
    # @TODO: rewrite save-load for no-decoder usage
    decoder_config = {
        "cell": rnn.LSTMCell(decoder_size),
        "attention": attention,
    }

    model = create_model(
        len(token2id) + special_symbol_count,
        embedding_size, encoder_config, decoder_config)

    # Load the trained weights into the session before encoding anything.
    tf.train.Saver().restore(sess, checkpoint_path)

    yield from rnn_encoder_encode_stream(
        sess, data_stream, model, batch_size, use_norm,
        lstm_connection=lstm_connection)
# NOTE(review): the two lines below ("评论列表" = "comment list", "文章目录" =
# "article table of contents") are stray webpage text pasted in with the code;
# commented out so the module remains valid Python.
# 评论列表
# 文章目录