def create_input(data, padding_id):
    """Pack a batch of documents into padded numpy arrays.

    Each item ``d`` in *data* is indexed as:

    * ``d[0]`` -- the document's sentences (each sentence a sized sequence),
    * ``d[1]`` -- the document's label,
    * ``d[2]`` -- one relevance value per sentence; must have ``len(d[0])``
      entries.

    Args:
        data: non-empty sized sequence of documents in the layout above.
        padding_id: padding value forwarded to ``create_doc_array``.

    Returns:
        ``[idxs, idys, gold_rels]`` where

        * ``idxs`` has shape ``(max_sent_len, max_doc_len, len(data))``;
          column ``j`` holds the flattened ``create_doc_array`` output of
          document ``j``,
        * ``idys`` is ``np.array([d[1] for d in data], dtype="int32")``,
        * ``gold_rels`` is int32 with shape ``(max_doc_len, len(data))``;
          column ``j`` is document ``j``'s relevance list left-padded with
          ``REL_PAD``.

        ``max_doc_len`` is the largest sentence count in the batch and
        ``max_sent_len`` the longest sentence; both are floored at 1.

    Raises:
        ValueError: if *data* is empty, or if a document's relevance list
            length differs from its sentence count.
    """
    batch_size = len(data)
    if batch_size == 0:
        raise ValueError("create_input: data must contain at least one document")

    # Validate up front so a misaligned document produces a clear error
    # instead of an opaque shape mismatch inside np.column_stack below.
    # (Previously an `assert`, which is stripped under `python -O`.)
    for i, d in enumerate(data):
        if len(d[2]) != len(d[0]):
            raise ValueError(
                "create_input: document %d has %d relevance labels but %d sentences"
                % (i, len(d[2]), len(d[0]))
            )

    # Floor both sizes at 1 so batches of empty documents / empty sentences
    # still yield non-degenerate arrays.
    max_doc_len = max(1, max(len(d[0]) for d in data))
    max_sent_len = max(1, max((len(x) for d in data for x in d[0]), default=0))

    # One column per document: shape (max_sent_len * max_doc_len, batch_size).
    idxs = np.column_stack(
        [
            create_doc_array(d, padding_id, max_doc_len, max_sent_len).ravel()
            for d in data
        ]
    )
    # NOTE(review): assumes create_doc_array returns max_sent_len * max_doc_len
    # values whose C-order flattening matches (max_sent_len, max_doc_len) --
    # confirm against create_doc_array.
    idxs = idxs.reshape(max_sent_len, max_doc_len, batch_size)

    idys = np.array([d[1] for d in data], dtype="int32")

    # Relevance labels, left-padded with REL_PAD to max_doc_len.
    # list(...) makes the concatenation work for tuples/arrays too; with a
    # numpy array, `list + ndarray` would broadcast-add instead of concatenate.
    gold_rels = np.column_stack(
        [
            np.array(
                [REL_PAD] * (max_doc_len - len(d[2])) + list(d[2]),
                dtype="int32",
            )
            for d in data
        ]
    )
    # Internal invariant: guaranteed by the validation above.
    assert gold_rels.shape == (max_doc_len, batch_size)

    input_lst = [idxs, idys, gold_rels]
    return input_lst
评论列表
文章目录