def _to_sparse_tensor_value(matrix):
    """Convert a sparse matrix into a ``tf.SparseTensorValue`` via COO format."""
    coo = matrix.tocoo()
    # One (row, col) pair per stored element -> array of shape (nnz, 2).
    indices = np.transpose([
        np.array(coo.row, dtype=np.int64),
        np.array(coo.col, dtype=np.int64),
    ])
    # np.float64 replaces the removed ``np.float`` alias (which was just the
    # builtin ``float``), so the resulting dtype is unchanged.
    values = np.array(coo.data, dtype=np.float64)
    dense_shape = np.array(coo.shape, dtype=np.int64)
    return tf.SparseTensorValue(indices, values, dense_shape)


def pull_batch(query_data, doc_data, batch_idx):
    """Slice one batch of query/doc rows and wrap each as a sparse tensor value.

    Rows ``batch_idx * BS`` up to (but excluding) ``(batch_idx + 1) * BS`` are
    taken from both inputs, where ``BS`` is a name defined outside this
    function.

    Args:
        query_data: 2-D sparse matrix supporting row slicing, ``getrow`` and
            ``tocoo`` (presumably a scipy.sparse CSR matrix -- confirm
            against the caller).
        doc_data: Sparse matrix sliced with the same row range as
            ``query_data``.
        batch_idx: Zero-based index of the batch to extract.

    Returns:
        A ``(query_in, doc_in)`` tuple of ``tf.SparseTensorValue`` objects
        with int64 indices, float64 values and an int64 dense shape.
    """
    first_row = batch_idx * BS
    last_row = (batch_idx + 1) * BS
    query_in = query_data[first_row:last_row, :]
    doc_in = doc_data[first_row:last_row, :]

    # Debug output for the first batch only. The row index is hard-coded, so
    # skip it when the batch is too small instead of failing on the lookup.
    debug_row = 53
    if batch_idx == 0 and query_in.shape[0] > debug_row:
        print(query_in.getrow(debug_row))

    return _to_sparse_tensor_value(query_in), _to_sparse_tensor_value(doc_in)
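# Comment list        (navigation text left over from the web page this snippet was copied from)
# Table of contents   (navigation text left over from the web page this snippet was copied from)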