def infer_column_schema_from_tensor(tensor):
    """Infer a ColumnSchema describing ``tensor``.

    Args:
        tensor: Either a ``tf.SparseTensor`` or a dense tensor-like object
            that exposes ``get_shape()`` and ``dtype``.

    Returns:
        A ``ColumnSchema`` built from the tensor's dtype, a list of axes, and
        a column representation:
          * for a ``tf.SparseTensor``: a single ``Axis(None)`` and a
            ``ListColumnRepresentation``;
          * otherwise: the axes derived from the static shape by
            ``_shape_to_axes`` (with ``remove_batch_dimension=True``) and a
            ``FixedColumnRepresentation``.
    """
    if not isinstance(tensor, tf.SparseTensor):
        # Dense case: derive the axes from the static shape, asking the helper
        # to strip the batch dimension.
        dense_axes = _shape_to_axes(tensor.get_shape(),
                                    remove_batch_dimension=True)
        return ColumnSchema(tensor.dtype, dense_axes,
                            FixedColumnRepresentation())

    # A SparseTensor on its own does not carry enough information to choose
    # between ListColumnRepresentation and SparseColumnRepresentation. The
    # policy is to guess the list form, because VarLenFeature is common in
    # tf.Learn code. Callers that need the sparse form must handle it
    # themselves, e.g. by requiring the user to supply the schema.
    list_axes = [Axis(None)]
    return ColumnSchema(tensor.dtype, list_axes, ListColumnRepresentation())
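# NOTE(review): stray web-page navigation text left over from scraping;
# kept as comments so the module parses.
# 评论列表 ("Comment list")
# 文章目录 ("Table of contents")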