def _add_n_or_sum(terms):
# add_n works for Tensors of the same dtype and shape
shape = terms[0].get_shape()
dtype = terms[0].dtype
if all(term.get_shape().is_fully_defined() and
term.get_shape().is_compatible_with(shape) and term.dtype == dtype
for term in terms):
return math_ops.add_n(terms)
else:
return sum(terms)
评论列表
文章目录