sparse_feature_cross_op.py 文件源码

python
阅读 19 收藏 0 点赞 0 评论 0

项目:lsdc 作者: febert 项目源码 文件源码
def sparse_feature_cross(inputs, hashed_output=False, num_buckets=0,
                         name=None):
  """Crosses a list of Tensor or SparseTensor objects.

  See sparse_feature_cross_kernel.cc for more details.

  Args:
    inputs: List of `SparseTensor` or `Tensor` to be crossed.
    hashed_output: If true, returns the hash of the cross instead of the
      string. This avoids string manipulation.
    num_buckets: Only used if hashed_output is true.
      output = hashed_value % num_buckets if num_buckets > 0 else hashed_value.
    name: A name prefix for the returned tensors (optional).

  Returns:
    A `SparseTensor` with the crossed features.
    Its values are string if hashed_output=False, int64 otherwise.

  Raises:
    TypeError: If `inputs` is not a list, or if any element of `inputs` is
      neither a `SparseTensor` nor a `Tensor`.
  """
  if not isinstance(inputs, list):
    raise TypeError("Inputs must be a list")
  if not all(isinstance(i, (ops.SparseTensor, ops.Tensor)) for i in inputs):
    # Dense Tensors are accepted too, so the message must mention both kinds.
    raise TypeError("All inputs must be SparseTensors or Tensors")

  sparse_inputs = [i for i in inputs if isinstance(i, ops.SparseTensor)]
  dense_inputs = [i for i in inputs if not isinstance(i, ops.SparseTensor)]

  indices = [sp_input.indices for sp_input in sparse_inputs]
  values = [sp_input.values for sp_input in sparse_inputs]
  shapes = [sp_input.shape for sp_input in sparse_inputs]
  out_type = dtypes.int64 if hashed_output else dtypes.string

  # The op crosses in string space unless at least one feature is non-string.
  # In that case, every non-string feature is cast to int64 and the op is told
  # to cross in int64. String features are passed through unchanged.
  # NOTE(review): how the kernel combines string features with an int64
  # internal type is defined in sparse_feature_cross_kernel.cc — confirm there.
  has_non_string = any(t.dtype != dtypes.string
                       for t in values + dense_inputs)
  internal_type = dtypes.int64 if has_non_string else dtypes.string
  values = [v if v.dtype == dtypes.string else math_ops.to_int64(v)
            for v in values]
  dense_inputs = [d if d.dtype == dtypes.string else math_ops.to_int64(d)
                  for d in dense_inputs]

  indices_out, values_out, shape_out = (
      _sparse_feature_cross_op.sparse_feature_cross(indices,
                                                    values,
                                                    shapes,
                                                    dense_inputs,
                                                    hashed_output,
                                                    num_buckets,
                                                    out_type=out_type,
                                                    internal_type=internal_type,
                                                    name=name))
  return ops.SparseTensor(indices_out, values_out, shape_out)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号