def shrink_embed(mat, word_ixs: List):
    """
    Restrict an embedding matrix to the rows referenced by `word_ixs`.

    Every tensor in `word_ixs` is flattened and the results are concatenated
    into one vector. `tf.unique` is applied to that vector, and only the rows
    of `mat` selected by the unique ids are gathered. Each input tensor is
    then rewritten as indices into the gathered matrix, keeping its original
    shape.

    Useful when dropout should be applied to the embeddings actually in use
    rather than to the entire embedding matrix.

    Args:
        mat: The full embedding matrix; rows are selected with `tf.gather`.
        word_ixs: Index tensors into `mat`. They may have different shapes.

    Returns:
        A pair `(shrunk, remapped)`: `shrunk` holds the gathered rows of
        `mat`, and `remapped` is a list with one tensor per entry of
        `word_ixs` (same order, same shape) indexing into `shrunk`.
    """
    # Flatten every index tensor so they can be joined into a single vector.
    flattened = []
    for ixs in word_ixs:
        flattened.append(tf.reshape(ixs, (-1,)))
    joined = tf.concat(flattened, axis=0)

    # `unique_words` are the distinct ids; `new_ids` gives, for every element
    # of `joined`, its position inside `unique_words`.
    unique_words, new_ids = tf.unique(joined)
    shrunk = tf.gather(mat, unique_words)

    # Cut the remapped vector back into one piece per input tensor; each
    # piece has as many elements as the corresponding input.
    sizes = [tf.reduce_prod(tf.shape(ixs)) for ixs in word_ixs]
    pieces = tf.split(new_ids, sizes)

    # Restore the original shape of every input tensor.
    remapped = []
    for piece, ixs in zip(pieces, word_ixs):
        remapped.append(tf.reshape(piece, tf.shape(ixs)))
    return shrunk, remapped
评论列表
文章目录