io_util.py 文件源码

python
阅读 30 收藏 0 点赞 0 评论 0

项目:aspect_adversarial 作者: yuanzh 项目源码 文件源码
def create_batches(data, batch_size, padding_id, label=True, sort=True, shuffle=True):
    """Split ``data`` into mini-batches built by ``create_input``.

    Args:
        data: sequence of examples. ``d[0]`` is measured with ``len`` for
            sorting and ``d[1]`` is the example's label.
        batch_size: number of examples per batch; the last batch may be
            smaller.
        padding_id: forwarded unchanged to ``create_input`` for every batch.
        label: when True, every example must have ``d[1] != -1`` (checked
            with ``assert``). When False, the label part of every built batch
            (``b[1]``) is overwritten in place with 0.
        sort: when True, examples are ordered by descending ``len(d[0])``
            before being cut into batches.
        shuffle: when True, the batch order is randomly permuted with
            ``np.random.permutation``, and the returned ``data`` is reordered
            the same way so it stays aligned with the batches.

    Returns:
        ``(batches, data)``: one ``create_input`` result per batch, and the
        examples in the same order in which they appear across ``batches``.
    """
    if label:
        for d in data:
            assert d[1] != -1
    if sort:
        data = sorted(data, key=lambda x: len(x[0]), reverse=True)

    # Keep the slices so the shuffled ``data`` can be rebuilt from them
    # without re-slicing. ``range`` (not py2-only ``xrange``) runs on 2 and 3.
    chunks = [data[i:i + batch_size] for i in range(0, len(data), batch_size)]
    # Each element is whatever ``create_input`` returns (idxs, idys, ...).
    batches = [create_input(chunk, padding_id) for chunk in chunks]

    if shuffle:
        perm = np.random.permutation(len(batches))
        new_batches = [batches[j] for j in perm]
        # Flatten in one linear pass. ``reduce(operator.add, ...)`` was
        # quadratic and raised TypeError when ``data`` was empty.
        new_data = [example for j in perm for example in chunks[j]]
        # Check *before* rebinding ``data``; afterwards the comparison would
        # be new_data against itself and could never fail.
        assert len(new_data) == len(data)
        batches, data = new_batches, new_data

    if not label:
        # Set every batch's label part to 0. This assumes ``b[1]`` supports
        # slice assignment (e.g. a numpy array); TODO confirm with create_input.
        for b in batches:
            b[1][:] = 0

    return batches, data
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号