batch.py 文件源码

python
阅读 31 收藏 0 点赞 0 评论 0

项目:segmenter 作者: yanshao9798 项目源码 文件源码
def predict_seq2seq(sess, model, decoding, data, decode_len, dr=None, argmax=True, batch_size=100, ensemble=False, verbose=False):
    """Run batched inference for a seq2seq graph and collect its predictions.

    ``model`` is a flat sequence of graph handles laid out as
    ``num_items * in_len`` input placeholders, followed by ``decode_len``
    decoder-input placeholders, followed by the prediction tensors to fetch.

    Args:
        sess: object exposing ``run(fetches, feed_dict=...)`` (presumably a
            TensorFlow ``Session`` -- confirm against callers).
        model: flat list/tuple of placeholders and prediction tensors, laid
            out as described above. It is never mutated.
        decoding: placeholder that is fed ``True`` on every batch.
        data: ``num_items`` parallel sequences; ``data[k][s]`` is the k-th
            input sequence of sample ``s``. Every input sequence is assumed to
            have the length of ``data[0][0]`` (they are stacked into one array).
        decode_len: number of decoder steps; each decoder-input placeholder is
            fed an int32 zero vector.
        dr: optional placeholder fed ``0.0`` (presumably a dropout rate --
            confirm).
        argmax: if True, take ``argmax(axis=1)`` of every fetched array and
            return one list of ids per sample, with zeros trimmed from both
            ends. If False, return a flat list that holds, batch after batch,
            the raw values returned by ``sess.run``.
        batch_size: maximum number of samples per ``sess.run`` call.
        ensemble: unused here; kept for interface compatibility.
        verbose: print a progress percentage before each batch.

    Returns:
        list of predictions, shaped as described under ``argmax``. It is
        empty when ``data`` contains no samples.

    Raises:
        ValueError: if ``batch_size`` is smaller than 1 (which would otherwise
            never advance through the samples).
    """
    if batch_size < 1:
        raise ValueError('batch_size must be >= 1, got %r' % (batch_size,))
    # list(...) keeps this working on Python 3, where zip() is lazy and
    # len()/slicing would fail; on Python 2 it is a harmless copy.
    samples = list(zip(*data))
    n_samples = len(samples)
    if n_samples == 0:
        return []
    num_items = len(data)
    in_len = len(data[0][0])
    n_inputs = num_items * in_len + decode_len
    # list(...) guarantees a fresh list: appending below never touches
    # ``model``, and tuple models are accepted as well.
    input_v = list(model[:n_inputs])
    input_v.append(decoding)
    if dr is not None:
        input_v.append(dr)
    predictions = model[n_inputs:]
    output = []
    start_idx = 0
    while start_idx < n_samples:
        if verbose:
            # Single-argument print() behaves identically on Python 2 and 3.
            print('%d' % (start_idx * 100 // n_samples) + '%')
        batch = samples[start_idx:start_idx + batch_size]
        cur_size = len(batch)  # the final batch may be smaller
        holders = []
        # Regroup the batch per input item, then transpose
        # (batch, in_len) -> (in_len, batch) so each row feeds one
        # per-position input placeholder, in the same order as ``input_v``.
        for item_batch in zip(*batch):
            holders.extend(np.asarray(item_batch).T)
        holders.extend(np.zeros(cur_size, dtype='int32') for _ in range(decode_len))
        holders.append(True)
        if dr is not None:
            holders.append(0.0)
        pre = sess.run(predictions, feed_dict=dict(zip(input_v, holders)))
        if argmax:
            # (n_fetches, batch) -> (batch, n_fetches): one id row per sample.
            ids = np.asarray([np.argmax(pre_t, axis=1) for pre_t in pre]).T
            # Trim the ndarray row, then convert to a list of Python ints.
            # The result matches trimming the .tolist() output.
            output.extend(np.trim_zeros(row).tolist() for row in ids)
        else:
            output += pre
        start_idx += cur_size
    return output
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号