dssm_v3.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:dssm 作者: liaha 项目源码 文件源码
def pull_batch(query_data, doc_data, batch_idx):
    # start = time.time()
    query_in = query_data[batch_idx * BS:(batch_idx + 1) * BS, :]
    doc_in = doc_data[batch_idx * BS:(batch_idx + 1) * BS, :]

    if batch_idx == 0:
      print(query_in.getrow(53))
    query_in = query_in.tocoo()
    doc_in = doc_in.tocoo()



    query_in = tf.SparseTensorValue(
        np.transpose([np.array(query_in.row, dtype=np.int64), np.array(query_in.col, dtype=np.int64)]),
        np.array(query_in.data, dtype=np.float),
        np.array(query_in.shape, dtype=np.int64))
    doc_in = tf.SparseTensorValue(
        np.transpose([np.array(doc_in.row, dtype=np.int64), np.array(doc_in.col, dtype=np.int64)]),
        np.array(doc_in.data, dtype=np.float),
        np.array(doc_in.shape, dtype=np.int64))

    # end = time.time()
    # print("Pull_batch time: %f" % (end - start))

    return query_in, doc_in
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号