sync-dssm-dist.py 文件源码

python
阅读 28 收藏 0 点赞 0 评论 0

项目:dssm 作者: liaha 项目源码 文件源码
def pull_batch(query_data, doc_data, batch_idx):
    query_in = query_data[batch_idx * BS:(batch_idx + 1) * BS, :]
    doc_in = doc_data[batch_idx * BS:(batch_idx + 1) * BS, :]
    cols = np.unique(np.concatenate((query_in.tocoo().col.T, doc_in.tocoo().col.T), axis=0))
    # print(query_in.shape)
    # print(doc_in.shape)
    query_in = query_in[:, cols].tocoo()
    doc_in = doc_in[:, cols].tocoo()

    query_in = tf.SparseTensorValue(
        np.transpose([np.array(query_in.row, dtype=np.int64), np.array(query_in.col, dtype=np.int64)]),
        np.array(query_in.data, dtype=np.float),
        np.array(query_in.shape, dtype=np.int64))
    doc_in = tf.SparseTensorValue(
        np.transpose([np.array(doc_in.row, dtype=np.int64), np.array(doc_in.col, dtype=np.int64)]),
        np.array(doc_in.data, dtype=np.float),
        np.array(doc_in.shape, dtype=np.int64))

    return query_in, doc_in, cols
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号