def _log_prob(self, given):
    logits = self.logits

    def _broadcast(given, logits):
        # Static shapes have been checked in the base class. Broadcast
        # `given` up to the batch shape of `logits`, and `logits` up to
        # the shape of `given`, by multiplying with ones.
        ones_ = tf.ones(tf.shape(logits)[:-1], self.dtype)
        if logits.get_shape():
            ones_.set_shape(logits.get_shape()[:-1])
        given *= ones_
        logits *= tf.ones_like(tf.expand_dims(given, -1), self.param_dtype)
        return given, logits
    def _is_same_dynamic_shape(given, logits):
        # Dynamically check whether the shape of `given` equals the batch
        # shape of `logits` (i.e. `logits` without its last, category axis).
        return tf.cond(
            tf.equal(tf.rank(given), tf.rank(logits) - 1),
            lambda: tf.reduce_all(tf.equal(
                tf.concat([tf.shape(given), tf.shape(logits)[:-1]], 0),
                tf.concat([tf.shape(logits)[:-1], tf.shape(given)], 0))),
            lambda: tf.convert_to_tensor(False, tf.bool))
    if not (given.get_shape() and logits.get_shape()):
        # Static shapes are unknown; broadcast unconditionally.
        given, logits = _broadcast(given, logits)
    else:
        if given.get_shape().ndims != logits.get_shape().ndims - 1:
            given, logits = _broadcast(given, logits)
        elif given.get_shape().is_fully_defined() and \
                logits.get_shape()[:-1].is_fully_defined():
            if given.get_shape() != logits.get_shape()[:-1]:
                given, logits = _broadcast(given, logits)
        else:
            # The code below seems to induce a BUG when this function is
            # called in HMC, probably because TensorFlow does not support
            # a control flow edge from an op inside the cond body to the
            # outside. This should be fixed further.
            #
            # given, logits = tf.cond(
            #     _is_same_dynamic_shape(given, logits),
            #     lambda: (given, logits),
            #     lambda: _broadcast(given, logits))
            given, logits = _broadcast(given, logits)
    # The `labels` argument of `sparse_softmax_cross_entropy_with_logits`
    # must be int32 or int64.
    if self.dtype == tf.float32:
        given = tf.cast(given, dtype=tf.int32)
    elif self.dtype == tf.float64:
        given = tf.cast(given, dtype=tf.int64)
    elif self.dtype not in [tf.int32, tf.int64]:
        given = tf.cast(given, tf.int32)
    # log p(given) is the negative cross entropy between the one-hot
    # `given` and softmax(logits).
    log_p = -tf.nn.sparse_softmax_cross_entropy_with_logits(labels=given,
                                                            logits=logits)
    if given.get_shape() and logits.get_shape():
        log_p.set_shape(tf.broadcast_static_shape(given.get_shape(),
                                                  logits.get_shape()[:-1]))
    return log_p
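
For intuition, here is a minimal standalone sketch (TensorFlow 1.x style, with made-up shapes and values, not part of the original class) of the multiply-by-ones trick that `_broadcast` uses: `given` is tiled up to the batch shape of `logits`, and `logits` is tiled up to the shape of `given`.

import tensorflow as tf

# Hypothetical example: logits has batch shape [2] and 3 categories;
# given is a scalar class index that must be tiled to the batch shape.
logits = tf.constant([[1., 2., 3.], [4., 5., 6.]])   # shape [2, 3]
given = tf.constant(1.)                              # shape []

ones_ = tf.ones(tf.shape(logits)[:-1], tf.float32)   # shape [2]
given_b = given * ones_                              # shape [2]
logits_b = logits * tf.ones_like(tf.expand_dims(given_b, -1))  # shape [2, 3]

with tf.Session() as sess:
    g, l = sess.run([given_b, logits_b])
    print(g.shape, l.shape)  # (2,) (2, 3)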
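The final step relies on the identity that `-sparse_softmax_cross_entropy_with_logits(labels, logits)` equals the log-softmax of `logits` indexed at `labels`, i.e. `logits[given] - logsumexp(logits)`. A quick numeric check (again a TF 1.x sketch with made-up values):

import tensorflow as tf

logits = tf.constant([[1., 2., 3.]])
given = tf.constant([2])

# log p(given) = logits[given] - logsumexp(logits)
log_p = -tf.nn.sparse_softmax_cross_entropy_with_logits(labels=given,
                                                        logits=logits)
manual = tf.nn.log_softmax(logits)[0, 2]

with tf.Session() as sess:
    a, b = sess.run([log_p, manual])
    print(a[0], b)  # both approximately -0.4076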