def embedding_lookup_sparse_sumexp(params, sp_ids, name=None, num_segments=None):
    """Sum ``exp(embedding)`` over the ids in each row of a sparse id tensor.

    For every segment ``r`` (taken from column 0 of ``sp_ids.indices``) this
    computes ``sum(exp(params[id]) for id in segment r)``.

    Args:
      params: Embedding table (or list of shards) accepted by
        ``tf.nn.embedding_lookup``.
      sp_ids: ``tf.SparseTensor`` whose values are ids into ``params``.
        Column 0 of its indices is used as the segment (row) id. Presumably
        a 2-D ``[batch, max_ids]`` tensor in canonical (row-sorted) order,
        which sparse segment ops require -- TODO confirm at call sites.
      name: Optional name for the final segment-sum op.
      num_segments: Optional number of output rows. When ``None`` (the
        default, matching the previous behaviour), the output ends at the
        last segment that contains an id, so trailing empty rows are
        dropped. Pass e.g. ``sp_ids.dense_shape[0]`` to keep one row per
        input row; empty rows then sum to zeros.

    Returns:
      A tensor with one row per segment, whose trailing dimensions are
      those of a single embedding.

    Note:
      ``tf.exp`` is applied to the raw embedding values, so large entries can
      overflow to ``inf``.
    """
    # SparseTensor indices are int64; convert the row ids to int32 for the
    # segment op, as the original implementation did.
    segment_ids = sp_ids.indices[:, 0]
    if segment_ids.dtype != tf.int32:
        segment_ids = tf.cast(segment_ids, tf.int32)

    # Look up and exponentiate each distinct id only once. `idx` maps every
    # original entry of sp_ids.values to its position in `unique_ids`, so the
    # segment sum below gathers the right exp-embedding for each entry.
    unique_ids, idx = tf.unique(sp_ids.values)
    embeddings = tf.exp(tf.nn.embedding_lookup(params, unique_ids))

    # `tf.sparse.segment_sum` is the TF2 endpoint. The TF1 alias
    # `tf.sparse_segment_sum` is not exposed at top level in TF2.
    return tf.sparse.segment_sum(
        embeddings, idx, segment_ids, num_segments=num_segments, name=name)
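# Leftover web-page navigation text (not code), kept for reference:
# "评论列表" (comment list), "文章目录" (table of contents).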