encode_stream.py source code


Project: YATS2S  Author: Scitator
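import tensorflow as tf  # TensorFlow 1.x API (tf.train.Saver, tf.contrib)
from tensorflow.contrib import rnn

# create_model and rnn_encoder_encode_stream are helpers defined elsewhere
# in the YATS2S project; they are not part of TensorFlow.


def encoder_pipeline(
        sess, data_stream, token2id, embedding_size,
        encoder_size, bidirectional, decoder_size, attention,
        checkpoint_path,
        batch_size=32, use_norm=False, lstm_connection=1):
    """Restore a trained seq2seq model from checkpoint_path and lazily
    yield one encoder embedding matrix per batch of data_stream."""

    # Encoder configuration: an LSTM cell, optionally run bidirectionally.
    encoder_args = {
        "cell": rnn.LSTMCell(encoder_size),
        "bidirectional": bidirectional,
    }

    # The decoder is not needed for encoding, but its arguments are still
    # required here because save-load does not yet support a decoder-free model.
    # @TODO: rewrite save-load for no-decoder usage
    decoder_args = {
        "cell": rnn.LSTMCell(decoder_size),
        "attention": attention,
    }

    # Vocabulary size = known tokens + 3 extra ids reserved for special symbols.
    spec_symbols_bias = 3
    model = create_model(
        len(token2id) + spec_symbols_bias, embedding_size, encoder_args, decoder_args)

    # Load the trained weights into the current session.
    saver = tf.train.Saver()
    saver.restore(sess, checkpoint_path)

    # Stream embeddings batch by batch instead of materializing them all at once.
    for embedding_matr in rnn_encoder_encode_stream(
            sess, data_stream, model, batch_size, use_norm, lstm_connection=lstm_connection):
        yield embedding_matr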
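encoder_pipeline hands the actual batching to rnn_encoder_encode_stream, whose source is not shown here. The sketch below illustrates the generator pattern it presumably follows: group the input stream into fixed-size batches, encode each batch, optionally normalize, and yield one matrix at a time. It is a minimal pure-Python sketch under stated assumptions: `encode_batch` is a hypothetical stand-in for running the restored TensorFlow graph, and `use_norm` is assumed to mean L2-normalizing each embedding row. None of these names come from YATS2S.

```python
import math
from typing import Callable, Iterable, Iterator, List

Vector = List[float]


def l2_normalize(rows: List[Vector]) -> List[Vector]:
    """Scale each row to unit L2 norm; all-zero rows are returned unchanged."""
    out = []
    for row in rows:
        norm = math.sqrt(sum(x * x for x in row))
        out.append([x / norm for x in row] if norm > 0 else list(row))
    return out


def batched(stream: Iterable[str], batch_size: int) -> Iterator[List[str]]:
    """Group an arbitrary iterable into lists of at most batch_size items."""
    batch: List[str] = []
    for item in stream:
        batch.append(item)
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch:  # emit the final, possibly smaller, batch
        yield batch


def encode_stream(
        stream: Iterable[str],
        encode_batch: Callable[[List[str]], List[Vector]],  # hypothetical model call
        batch_size: int = 32,
        use_norm: bool = False) -> Iterator[List[Vector]]:
    """Lazily yield one embedding matrix (list of row vectors) per batch."""
    for batch in batched(stream, batch_size):
        matrix = encode_batch(batch)
        if use_norm:  # assumption: use_norm means per-row L2 normalization
            matrix = l2_normalize(matrix)
        yield matrix


# Usage with a toy "encoder" that maps each string to [length, 0.0]:
toy = lambda batch: [[float(len(s)), 0.0] for s in batch]
for matrix in encode_stream(["ab", "abc", "a"], toy, batch_size=2):
    print(matrix)
```

Because every stage is a generator, only one batch is held in memory at a time, which is what lets a pipeline like encoder_pipeline run over data streams larger than RAM.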