embedder.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:document-qa 作者: allenai 项目源码 文件源码
def shrink_embed(mat, word_ixs: List):
    """
    Build an embedding matrix that contains only the elements in `word_ixs`,
    and map `word_ixs` to tensors the index into they new embedding matrix.
    Useful if you want to dropout the embeddings w/o dropping out the entire matrix
    """
    all_words, out_id = tf.unique(tf.concat([tf.reshape(x, (-1,)) for x in word_ixs], axis=0))
    mat = tf.gather(mat, all_words)
    partitions = tf.split(out_id, [tf.reduce_prod(tf.shape(x)) for x in word_ixs])
    partitions = [tf.reshape(x, tf.shape(o)) for x,o in zip(partitions, word_ixs)]
    return mat, partitions
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号