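# torch_util.py 文件源码 (torch_util.py source file)
# Language: python
# 阅读 24 收藏 0 点赞 0 评论 0 (views 24, favorites 0, likes 0, comments 0)
#
# 项目: multiNLI_encoder  作者: easonnie  (project: multiNLI_encoder, author: easonnie)
# 项目源码 | 文件源码 (links: project source | file source)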
def text_conv1d(inputs, l1, conv_filter: nn.Linear, k_size, dropout=None, list_in=False,
                gate_way=True):
    """Apply a gated (GLU-style) 1-D convolution to each variable-length sequence.

    Each sequence is zero-padded and cut into one window of width ``k_size``
    per time step, so the output keeps the input length. Each window is
    flattened and passed through ``conv_filter``. The filter output is split
    in half along the feature axis into ``(a, b)``, and the result is
    ``a * sigmoid(b)``.

    :param inputs: padded time-major tensor of shape [T, B, D_in]; or, when
        ``list_in`` is True, a list of B tensors, the i-th of shape [l1[i], D_in].
    :param l1: 1-D tensor holding the B sequence lengths.
    :param conv_filter: linear layer mapping ``k_size * D_in`` -> ``2 * D_out``.
        Its output is chunked into two equal halves, so out_features should be even.
    :param k_size: window width. Odd widths pad ``(k_size - 1) // 2`` zero rows
        on both sides. Even widths put the one extra zero row on the right.
    :param dropout: currently unused; kept for interface compatibility.
    :param list_in: treat ``inputs`` as a list of per-sequence tensors.
    :param gate_way: if True, concatenate each sequence's input features with
        its convolution output along the feature axis.
    :return: list of B tensors. The i-th has shape [l1[i], D_out], or
        [l1[i], D_in + D_out] when ``gate_way`` is True.
    """
    k = k_size
    batch_size = l1.size(0)
    d_in = inputs[0].size(1) if list_in else inputs.size(2)

    # Exactly k - 1 padding rows in total, so every time step produces one
    # full window. For odd k this equals the old symmetric (k - 1) // 2
    # padding. For even k, the old scheme ran past the end of the sequence.
    pad_left = (k - 1) // 2
    pad_right = k - 1 - pad_left
    zeros_left = Variable(inputs[0].data.new(pad_left, d_in).zero_())
    zeros_right = Variable(inputs[0].data.new(pad_right, d_in).zero_())

    lengths = [int(l1[b_i]) for b_i in range(batch_size)]
    seqs = []
    windows = []
    for b_i, seq_len in enumerate(lengths):
        masked_in = inputs[b_i] if list_in else inputs[:seq_len, b_i, :]
        seqs.append(masked_in)
        if seq_len == 0:
            # No windows for an empty sequence; unfold would reject size k > k - 1.
            continue
        padded = torch.cat([zeros_left, masked_in, zeros_right], dim=0)
        # unfold -> [seq_len, d_in, k]; transpose to [seq_len, k, d_in] so that
        # row i equals padded[i:i + k] flattened row-major (i.e. .view(k * d_in)).
        win = padded.unfold(0, k, 1).transpose(1, 2).contiguous()
        windows.append(win.view(seq_len, k * d_in))

    batch_in = torch.cat(windows, dim=0)
    a, b = torch.chunk(conv_filter(batch_in), 2, 1)
    out = a * torch.sigmoid(b)

    # Rows of `out` are laid out sequence after sequence, in batch order.
    out_list = []
    start = 0
    for masked_in, seq_len in zip(seqs, lengths):
        seq_out = out[start:start + seq_len]
        if gate_way:
            out_list.append(torch.cat((masked_in, seq_out), dim=1))
        else:
            out_list.append(seq_out)
        start += seq_len

    return out_list
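# 评论列表 (comment list)
# 文章目录 (table of contents)
#
# 问题 (questions) | 面经 (interview notes) | 文章 (articles)
#
# 微信公众号 (WeChat official account)
# 扫码关注公众号 (scan the QR code to follow the official account)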