def _get_weight_tensor(features, weight_column_name):
"""Returns the weight tensor of shape [batch_size] or 1."""
if weight_column_name is None:
return 1.0
else:
return array_ops.reshape(
math_ops.to_float(features[weight_column_name]),
shape=(-1,))
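# NOTE(review): the original text here appears to be web-page navigation
# labels captured with the snippet, not code: "Comment list" and
# "Article table of contents".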