def _nelem(x): nelem = tf.reduce_sum(tf.cast(~tf.is_nan(x), tf.float32)) return tf.cast(tf.where(tf.equal(nelem, 0.), 1., nelem), x.dtype)