def weighted_sum(components, weights, scope=""):
    """Combine `components` with `weights` in log-space via logsumexp.

    Adds (log-)weights to the stacked components and reduces with
    logsumexp over the component axis, i.e. a weighted sum performed in
    log-space. (Assumes `components` and `weights` hold log-domain
    values — TODO confirm against callers.)

    Args:
        components: list of `n` tensors, each broadcastable to [b x c];
            made batch-consistent via `make_batch_consistent`.
        weights: rank-1 tensor [n] (shared across the batch) or rank-2
            tensor [b x n] (per-example weights).
        scope: optional name scope for the created ops.

    Returns:
        [b x c] tensor, or [c] if the inputs were not batched.
    """
    # n: num_components
    # b: batch_size
    # c: component_size
    with tf.name_scope(scope):
        weight_is_batched = (weights.get_shape().ndims == 2)
        if weight_is_batched:
            # Per-example weights fix the batch size for the components.
            set_batch_size = tf.shape(weights)[0]
        else:
            set_batch_size = None
        components, is_batched = make_batch_consistent(components, set_batch_size=set_batch_size)
        components = tf.pack(components)  # [n x b x c]
        weight_rank = weights.get_shape().ndims
        assert_rank_1_or_2(weight_rank)
        if weight_rank == 1:
            weights = tf.reshape(weights, [-1, 1, 1])  # [n x 1 x 1]
        elif weight_rank == 2:
            weights = tf.expand_dims(tf.transpose(weights, [1, 0]), 2)  # [n x b x 1]
        # In log-space, adding the log-weights multiplies the components.
        components += weights
        # TODO: change this to tf.reduce_logsumexp when it is released
        w_sum = logsumexp(components, reduction_indices=0)  # [b x c]
        if not is_batched:
            # Squeeze ONLY the batch axis. A plain tf.squeeze would also
            # collapse the component axis whenever c == 1, yielding a
            # scalar instead of the intended [c] tensor.
            w_sum = tf.squeeze(w_sum, [0])  # [c]
        return w_sum