def _coo_to_sparse_value(coo):
    """Convert a COO-format sparse matrix into a ``tf.SparseTensorValue``."""
    # One (row, col) pair per stored entry -> index array of shape (nnz, 2).
    indices = np.transpose([
        np.array(coo.row, dtype=np.int64),
        np.array(coo.col, dtype=np.int64),
    ])
    # ``np.float`` was only an alias for the builtin ``float`` (i.e. float64)
    # and was removed in NumPy 1.24; name the dtype explicitly instead.
    values = np.array(coo.data, dtype=np.float64)
    dense_shape = np.array(coo.shape, dtype=np.int64)
    return tf.SparseTensorValue(indices, values, dense_shape)


def pull_batch(query_data, doc_data, batch_idx):
    """Slice one mini-batch out of the query and doc matrices.

    Rows ``batch_idx * BS`` up to (but excluding) ``(batch_idx + 1) * BS`` are
    taken from both inputs. Both slices are then restricted to the sorted
    union of the columns that hold a stored entry in either slice, so the
    column indices inside the returned tensors are positions within ``cols``
    rather than column ids of the input matrices.

    Args:
        query_data: 2-D sparse matrix supporting row slicing, column fancy
            indexing and ``tocoo()`` (presumably a ``scipy.sparse`` CSR
            matrix -- confirm against the caller).
        doc_data: Sparse matrix of the same kind, sliced with the same row
            range as ``query_data``.
        batch_idx: Zero-based index of the batch; the batch size is the
            module-level constant ``BS``.

    Returns:
        A tuple ``(query_in, doc_in, cols)`` where ``query_in`` and ``doc_in``
        are ``tf.SparseTensorValue`` objects (int64 indices, float64 values,
        int64 dense shape) and ``cols`` is the sorted array of unique column
        indices that were kept.
    """
    start = batch_idx * BS
    stop = (batch_idx + 1) * BS
    query_rows = query_data[start:stop, :]
    doc_rows = doc_data[start:stop, :]

    # Columns with at least one stored entry in either slice; np.unique also
    # sorts them, which fixes the column order of the compacted matrices.
    cols = np.unique(
        np.concatenate((query_rows.tocoo().col, doc_rows.tocoo().col), axis=0)
    )

    query_in = _coo_to_sparse_value(query_rows[:, cols].tocoo())
    doc_in = _coo_to_sparse_value(doc_rows[:, cols].tocoo())
    return query_in, doc_in, cols
# Scraped web-page residue, not part of the code:
# "Comment list" / "Table of contents"