univariate.py 文件源码

python
阅读 36 收藏 0 点赞 0 评论 0

项目:zhusuan 作者: thu-ml 项目源码 文件源码
def _log_prob(self, given):
        """Return the log-probability of the category indices in `given`.

        The result is ``-sparse_softmax_cross_entropy_with_logits(given,
        logits)``, i.e. ``log softmax(self.logits)[..., given]``, evaluated
        elementwise after broadcasting `given` against the batch shape of
        `self.logits`. The last axis of `self.logits` is treated as the
        category axis.

        Args:
            given: Tensor of category indices with dtype ``self.dtype``
                (presumably already converted/validated by the base class —
                confirm against the caller).

        Returns:
            Tensor of log-probabilities whose shape is the broadcast of
            ``given.shape`` and ``self.logits.shape[:-1]``.
        """
        logits = self.logits

        def _broadcast(given, logits):
            # Broadcast `given` to the batch shape of `logits` (every axis but
            # the last), then broadcast `logits` to `given`'s shape plus the
            # trailing category axis. Static shape compatibility has already
            # been checked in the base class.
            ones_ = tf.ones(tf.shape(logits)[:-1], self.dtype)
            if logits.get_shape():
                # Re-attach the static batch shape that building the ones
                # from the dynamic `tf.shape` would otherwise lose.
                ones_.set_shape(logits.get_shape()[:-1])
            given *= ones_
            logits *= tf.ones_like(tf.expand_dims(given, -1), self.param_dtype)
            return given, logits

        def _static_shapes_match(given_shape, logits_shape):
            # True only when the static shapes prove `given` already has
            # exactly the batch shape of `logits`. Any missing information
            # (unknown rank, partially-defined dims) conservatively yields
            # False so the caller broadcasts.
            if not (given_shape and logits_shape):
                return False
            if given_shape.ndims != logits_shape.ndims - 1:
                return False
            batch_shape = logits_shape[:-1]
            return (given_shape.is_fully_defined() and
                    batch_shape.is_fully_defined() and
                    given_shape == batch_shape)

        # When static shapes cannot prove a match we always broadcast, rather
        # than choosing dynamically with `tf.cond`. A `tf.cond`-based dynamic
        # shape check was tried here but seemed to induce a bug when this
        # function is called in HMC, probably because TensorFlow does not
        # support a control-flow edge from an op inside the cond body to
        # outside it.
        # TODO: revisit once that issue is understood.
        if not _static_shapes_match(given.get_shape(), logits.get_shape()):
            given, logits = _broadcast(given, logits)

        # `labels` of `sparse_softmax_cross_entropy_with_logits` must be
        # int32 or int64: map float32 -> int32, float64 -> int64, and any
        # other non-integer dtype -> int32.
        if self.dtype == tf.float32:
            given = tf.cast(given, dtype=tf.int32)
        elif self.dtype == tf.float64:
            given = tf.cast(given, dtype=tf.int64)
        elif self.dtype not in [tf.int32, tf.int64]:
            given = tf.cast(given, tf.int32)
        log_p = -tf.nn.sparse_softmax_cross_entropy_with_logits(labels=given,
                                                                logits=logits)
        if given.get_shape() and logits.get_shape():
            # Attach the statically-known result shape (presumably in case
            # the op's own shape inference is less precise).
            log_p.set_shape(tf.broadcast_static_shape(given.get_shape(),
                                                      logits.get_shape()[:-1]))
        return log_p
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号