def _check_input_shape(self, given):
    """Convert `given` to a Tensor and statically validate its shape.

    `given` is converted with ``tf.convert_to_tensor`` using ``self.dtype``.
    If the static shape of the converted tensor, the distribution's batch
    shape and its value shape are all truthy, the tensor's shape is checked
    with ``tf.broadcast_static_shape`` against ``batch_shape + value_shape``.
    If any of the three is falsy, no static check is performed.

    :param given: A Tensor, or a value convertible to one.
    :return: `given` converted to a Tensor of ``self.dtype``.
    :raises ValueError: If the static check runs and the two shapes cannot
        be broadcast together. The message reports the given shape, the
        batch shape and the value shape.
    """
    given = tf.convert_to_tensor(given, dtype=self.dtype)

    # Guard clause: skip the static check unless all three shapes are truthy.
    # NOTE(review): for tf.TensorShape, truthiness presumably means "rank is
    # known", which the as_list() calls below rely on -- confirm.
    if not (given.get_shape() and self.get_batch_shape() and
            self.get_value_shape()):
        return given

    # Assumes get_batch_shape()/get_value_shape() are side-effect-free
    # getters returning TensorShape-like objects (they expose as_list()).
    batch_shape = self.get_batch_shape()
    value_shape = self.get_value_shape()
    expected_shape = tf.TensorShape(
        batch_shape.as_list() + value_shape.as_list())

    try:
        tf.broadcast_static_shape(given.get_shape(), expected_shape)
    except ValueError:
        message = ("The given argument should be able to broadcast to "
                   "match batch_shape + value_shape of the distribution."
                   " ({} vs. {} + {})").format(
                       given.get_shape(), batch_shape, value_shape)
        raise ValueError(message)
    return given
评论列表
文章目录