def pad_batch(mini_batch):
    """Zero-pad the two integer sequences of every example in a mini-batch.

    Each example ``x`` in ``mini_batch`` is read as ``x[0]`` and ``x[1]``;
    both are treated as sequences of integer ids (presumably a sentence
    pair of token ids -- confirm against the caller). Each sequence is
    copied, left-aligned, into one row of a zero matrix. The matrix width
    is the longest sequence found at that position across the batch.

    Args:
        mini_batch: sequence of examples; each example exposes two
            integer sequences at indices 0 and 1.

    Returns:
        A two-element list ``[first, second]`` of ``Variable``-wrapped
        int64 tensors. Their shapes are ``(len(mini_batch), max_len_0)``
        and ``(len(mini_batch), max_len_1)``. Positions past the end of a
        sequence hold 0, so 0 is assumed to be the padding id
        (TODO confirm with the vocabulary).
    """
    batch_size = len(mini_batch)
    # The `or [0]` fallback keeps an empty batch from raising
    # (max()/np.max() of an empty list raise ValueError).
    max_len1 = int(max([len(x[0]) for x in mini_batch] or [0]))
    max_len2 = int(max([len(x[1]) for x in mini_batch] or [0]))
    # np.int was removed in NumPy 1.24. Explicit int64 also yields torch
    # LongTensors on every platform.
    matrix1 = np.zeros((batch_size, max_len1), dtype=np.int64)
    matrix2 = np.zeros((batch_size, max_len2), dtype=np.int64)
    for row, example in enumerate(mini_batch):
        # Index by (row, position). The old code indexed by
        # (example, token value) and silently swallowed the resulting
        # errors, so it never wrote the intended values.
        first, second = example[0], example[1]
        matrix1[row, :len(first)] = first
        matrix2[row, :len(second)] = second
    return [
        Variable(torch.from_numpy(matrix1)),
        Variable(torch.from_numpy(matrix2)),
    ]
# def pad_batch(mini_batch):
# # print mini_batch
# # print type(mini_batch)
# # print mini_batch.shape
# # for i, _ in enumerate(mini_batch):
# # print i, _
# return [Variable(torch.from_numpy(np.asarray(_))) for _ in mini_batch[0]]
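# Web-page residue from the original source (not code):
# "评论列表" ("Comment list"), "文章目录" ("Table of contents").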