import torch
import torch.nn as nn


def text_conv1d(inputs, l1, conv_filter: nn.Linear, k_size, dropout=None, list_in=False,
                gate_way=True):
"""
:param inputs: [T * B * D]
:param l1: [B]
:param conv_filter: [k * D_in, D_out * 2]
:param k_size:
:param dropout:
:param padding:
:param list_in:
:return:
"""
    k = k_size
    batch_size = l1.size(0)
    d_in = inputs.size(2) if not list_in else inputs[0].size(1)
    pad_n = (k - 1) // 2  # symmetric zero padding keeps output length equal to input length (odd k)
    zeros_padding = inputs[0].new_zeros(pad_n, d_in)
    batch_list = []
    input_list = []
    for b_i in range(batch_size):
        # keep only the valid (non-padded) time steps of example b_i
        masked_in = inputs[:l1[b_i], b_i, :] if not list_in else inputs[b_i]
        if gate_way:
            input_list.append(masked_in)
        b_inputs = torch.cat([zeros_padding, masked_in, zeros_padding], dim=0)
        for i in range(l1[b_i]):
            # every window of k consecutive rows, flattened to k * d_in
            batch_list.append(b_inputs[i:i + k].view(k * d_in))
    # run all windows from all examples through the filter in a single matmul
    batch_in = torch.stack(batch_list, dim=0)
    a, b = torch.chunk(conv_filter(batch_in), 2, 1)
    out = a * torch.sigmoid(b)  # gated linear unit: a carries content, b gates it
    # split the flat window batch back into per-example sequences
    out_list = []
    start = 0
    for b_i in range(batch_size):
        if gate_way:
            # highway-style shortcut: raw input features alongside the gated ones
            out_list.append(torch.cat((input_list[b_i], out[start:start + l1[b_i]]), dim=1))
        else:
            out_list.append(out[start:start + l1[b_i]])
        start = start + l1[b_i]
    # (an optional max-over-time pooling of each out_list entry would give
    # one fixed-size vector per example)
    return out_list
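

# ----------------------------------------------------------------------
# Usage sketch (illustrative only; the shapes and names below are
# assumptions, not part of the original code). The Linear layer acts as
# the filter bank: it maps each flattened window of k * d_in values to
# 2 * d_out values, half for the content and half for the sigmoid gate.
if __name__ == '__main__':
    torch.manual_seed(0)
    T, B, d_in, d_out, k = 7, 3, 8, 16, 3
    conv = nn.Linear(k * d_in, d_out * 2)
    x = torch.randn(T, B, d_in)             # [T, B, D], time-major batch
    lengths = torch.tensor([7, 5, 4])       # valid length of each sequence
    outs = text_conv1d(x, lengths, conv, k)
    for o in outs:
        # with gate_way=True each row carries d_in raw + d_out gated features
        print(o.shape)                      # [7, 24], [5, 24], [4, 24]

    # The hand-rolled window loop is equivalent to Tensor.unfold on one
    # padded sequence; this check mirrors what the loop builds for b_i = 0.
    pad = x.new_zeros((k - 1) // 2, d_in)
    padded = torch.cat([pad, x[:, 0, :], pad], dim=0)    # [T + k - 1, D]
    windows = padded.unfold(0, k, 1).transpose(1, 2).reshape(T, k * d_in)
    loop_rows = torch.stack([padded[i:i + k].reshape(k * d_in) for i in range(T)])
    assert torch.equal(windows, loop_rows)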