base.py 文件源码

python
阅读 29 收藏 0 点赞 0 评论 0

项目:zhusuan 作者: thu-ml 项目源码 文件源码
def _check_input_shape(self, given):
        given = tf.convert_to_tensor(given, dtype=self.dtype)

        err_msg = "The given argument should be able to broadcast to " \
                  "match batch_shape + value_shape of the distribution."
        if (given.get_shape() and self.get_batch_shape() and
                self.get_value_shape()):
            static_sample_shape = tf.TensorShape(
                self.get_batch_shape().as_list() +
                self.get_value_shape().as_list())
            try:
                tf.broadcast_static_shape(given.get_shape(),
                                          static_sample_shape)
            except ValueError:
                raise ValueError(
                    err_msg + " ({} vs. {} + {})".format(
                        given.get_shape(), self.get_batch_shape(),
                        self.get_value_shape()))
        return given
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号