def to_weighted_sum(self,
                    input_tensor,
                    num_outputs=1,
                    weight_collections=None,
                    trainable=True):
    """Create the weighted-sum lookup for this column and its weights.

    Delegates to `_create_embedding_lookup`:
      - The initializer is `init_ops.zeros_initializer`.
      - The lookup dimension is `num_outputs`.
      - The vocabulary size is `self.length`.
      - The combiner is this column's `self.combiner`.
      - No per-id weight tensor is supplied (`weight_tensor=None`).

    If `self.ckpt_to_load_from` is not None, the created weights are also
    registered with `checkpoint_utils.init_from_checkpoint` under the
    checkpoint name `self.tensor_name_in_ckpt`.

    Args:
      input_tensor: Forwarded unchanged as the lookup's `input_tensor`.
        Presumably this is this column's sparse id tensor (TODO confirm
        against callers).
      num_outputs: Used as the lookup `dimension`. Defaults to 1.
      weight_collections: Collections for the created weights. They are
        passed through `_add_variable_collection` before use.
      trainable: Forwarded unchanged to `_create_embedding_lookup`.

    Returns:
      The `(output, embedding_weights)` pair produced by
      `_create_embedding_lookup`, returned as-is.
    """
    lookup_output, weight_vars = _create_embedding_lookup(
        input_tensor=input_tensor,
        weight_tensor=None,
        vocab_size=self.length,
        dimension=num_outputs,
        weight_collections=_add_variable_collection(weight_collections),
        initializer=init_ops.zeros_initializer,
        combiner=self.combiner,
        trainable=trainable)

    if self.ckpt_to_load_from is not None:
        # A one-element weights sequence is unwrapped, so the checkpoint
        # entry maps to that single object. Otherwise the whole sequence is
        # mapped (NOTE(review): presumably the shards of a partitioned
        # variable -- confirm).
        restore_target = (weight_vars[0] if len(weight_vars) == 1
                          else weight_vars)
        checkpoint_utils.init_from_checkpoint(
            self.ckpt_to_load_from,
            {self.tensor_name_in_ckpt: restore_target})

    return lookup_output, weight_vars
评论列表
文章目录