lstm_parser_bi_fast.py source code

python

Project: depccg  Author: masashi-y
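from chainer import cuda  # assumed import: `cuda` is not defined in this excerpt


def concat_examples(batch, device=None):
    # Turn a list of per-example tuples into a single batch tuple and
    # optionally move the arrays to a device (device < 0 means CPU,
    # device >= 0 is a GPU id, None leaves the arrays where they are).
    # `_concat_arrays` is not shown in this excerpt; it must be defined or
    # imported elsewhere in the module. Its second argument (-1 or None)
    # appears to be the padding value.
    if len(batch) == 0:
        raise ValueError('batch is empty')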

    if device is None:
        def to_device(x):
            return x
    elif device < 0:
        to_device = cuda.to_cpu
    else:
        def to_device(x):
            return cuda.to_gpu(x, device, cuda.Stream.null)

    result = [to_device(_concat_arrays([s[0] for s in batch], -1)), # ws
              to_device(_concat_arrays([s[1] for s in batch], -1)), # ps
              to_device(_concat_arrays([s[2] for s in batch], -1)), # ss
              [s[3] for s in batch]]                                # ls

    if len(batch[0]) == 7:
        result.append([to_device(s[4]) for s in batch])            # cat_ts
        result.append([to_device(s[5]) for s in batch])            # dep_ts
        result.append(to_device(_concat_arrays([s[6] for s in batch], None))) # weights

    return tuple(result)
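
The function above relies on `_concat_arrays`, which this excerpt does not show. As a rough illustration only, the sketch below assumes it behaves like Chainer's private helper of the same name: with a padding value, each array is copied into a block filled with that value and sized to the largest shape in the batch; with `None`, the arrays are simply stacked. The name `pad_and_stack` is hypothetical and the code uses NumPy only, so it does not cover the GPU path.

```python
import numpy as np


def pad_and_stack(arrays, padding):
    """Hypothetical NumPy-only stand-in for `_concat_arrays` (assumed behavior)."""
    arrays = [np.asarray(a) for a in arrays]
    if padding is None:
        # No padding requested: every array must already have the same shape.
        return np.stack(arrays)
    # Largest extent along each axis (assumes all arrays share one ndim).
    max_shape = tuple(int(n) for n in np.max([a.shape for a in arrays], axis=0))
    out = np.full((len(arrays),) + max_shape, padding, dtype=arrays[0].dtype)
    for i, a in enumerate(arrays):
        # Copy each array into the top-left corner of its padded slot.
        index = (i,) + tuple(slice(0, n) for n in a.shape)
        out[index] = a
    return out


# Two "sentences" of different lengths, padded with -1 as in concat_examples.
batch = [np.array([3, 1, 4], dtype=np.int32), np.array([1, 5], dtype=np.int32)]
ws = pad_and_stack(batch, -1)
```

Here `ws` has shape `(2, 3)`, and the shorter example is filled out with `-1`, which downstream code can treat as an ignore index.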