def _create_joint_embedding_lookup(columns_to_tensors,
                                   embedding_lookup_arguments,
                                   num_outputs,
                                   trainable,
                                   weight_collections):
    """Creates an embedding lookup for all columns sharing a single weight.

    Each column's sparse ids are shifted by the combined ``vocab_size`` of the
    columns before it. The shifted sparse tensors are concatenated along axis
    1, and one ``[sum of vocab_size, num_outputs]`` float32 weight matrix
    (zero-initialized) is looked up with a ``'sum'`` combiner.

    Args:
      columns_to_tensors: Mapping whose ``values()`` are handed to the
        ``linear_weights`` variable scope as its ``values`` argument
        (presumably column -> input tensor; confirm against callers).
      embedding_lookup_arguments: Non-empty sequence of lookup arguments.
        Each one must expose:
        - ``input_tensor``: a sparse tensor with ``indices``, ``values`` and
          ``dense_shape``;
        - ``vocab_size``;
        - ``weight_tensor``: must be ``None``;
        - ``combiner``: must be ``'sum'``.
      num_outputs: Second dimension of the shared weight matrix.
      trainable: Forwarded to ``model_variable``.
      weight_collections: Collections the weight variable is added to.

    Returns:
      A ``(variable, predictions)`` tuple.
      - ``variable``: a list of the weight variable(s).
      - ``predictions``: the output of ``safe_embedding_lookup_sparse``.

    Raises:
      AssertionError: If ``embedding_lookup_arguments`` is empty, or any
        argument has a weight tensor or a combiner other than ``'sum'``.
        These are explicit raises rather than ``assert`` statements so the
        checks still run under ``python -O``. The exception type is kept
        as ``AssertionError`` for callers that already expect it.
    """
    if not embedding_lookup_arguments:
        raise AssertionError('At least one column must be in the model.')
    for arg in embedding_lookup_arguments:
        # The lookup below ignores per-id weights and always sums, so such
        # columns would otherwise be silently mis-computed.
        if arg.weight_tensor is not None:
            raise AssertionError(
                'Joint sums for weighted sparse columns are not supported. '
                'Please use weighted_sum_from_feature_columns instead.')
        if arg.combiner != 'sum':
            raise AssertionError(
                'Combiners other than sum are not supported for joint sums. '
                'Please use weighted_sum_from_feature_columns instead.')

    # Offset each column's ids past the vocabularies of the preceding columns,
    # so every column addresses its own row range of the shared weights.
    prev_size = 0
    sparse_tensors = []
    for arg in embedding_lookup_arguments:
        t = arg.input_tensor
        values = t.values + prev_size
        prev_size += arg.vocab_size
        sparse_tensors.append(
            sparse_tensor_py.SparseTensor(t.indices, values, t.dense_shape))
    # Axis 1 concat; presumably all inputs share the same batch dimension 0.
    sparse_tensor = sparse_ops.sparse_concat(1, sparse_tensors)

    with variable_scope.variable_scope(
            None, default_name='linear_weights',
            values=columns_to_tensors.values()):
        # After the loop, prev_size is the total vocabulary across all columns.
        variable = contrib_variables.model_variable(
            name='weights',
            shape=[prev_size, num_outputs],
            dtype=dtypes.float32,
            initializer=init_ops.zeros_initializer(),
            trainable=trainable,
            collections=weight_collections)
        # model_variable may return something other than a plain Variable
        # (presumably a partitioned variable); normalize to a list of
        # variables for the lookup in either case.
        if isinstance(variable, variables.Variable):
            variable = [variable]
        else:
            variable = variable._get_variable_list()  # pylint: disable=protected-access
        predictions = embedding_ops.safe_embedding_lookup_sparse(
            variable,
            sparse_tensor,
            sparse_weights=None,
            combiner='sum',
            name='_weights')
    return variable, predictions
feature_column_ops.py 文件源码
python
阅读 16
收藏 0
点赞 0
评论 0
评论列表
文章目录