import numpy as np
import tensorflow as tf

# one_hot, slice_out_int_literals, make_batch_consistent and logsumexp are
# helper functions assumed to be defined elsewhere in the same module.


def apply_factor(tensor, *args, **kwargs):
    scope = kwargs.pop("scope", "")
    with tf.name_scope(scope):
        n_args = len(args)
        if n_args == 0:
            # no arguments: return the factor's value as a one-hot
            # distribution over output_size symbols
            tensor, output_size, error_symbol = tensor
            return one_hot(tensor, output_size, scope=scope)
        else:
            tensor, args = slice_out_int_literals(tensor, list(args))
            args, is_batched = make_batch_consistent(args)
            tensor, output_size, error_symbol = tensor
            # handle the case where all arguments were int literals
            tensor_dim_sizes = [dim.value for dim in tensor.get_shape()]
            if not tensor_dim_sizes:
                return one_hot(tensor, output_size, scope=scope)
            # Each arg is batch size x arg dim. Add dimensions to enable broadcasting.
            for i, arg in enumerate(args):
                for j in range(len(args)):
                    if j == i:
                        continue
                    args[i] = tf.expand_dims(args[i], j + 1)
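            # For example, with two arguments of shape [batch_size, d0] and
            # [batch_size, d1], the loop above yields shapes [batch_size, d0, 1]
            # and [batch_size, 1, d1], so the sum below broadcasts to the full
            # joint of shape [batch_size, d0, d1].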
            # compute joint before tensor is applied
            joint = 0
            for arg in args:
                joint = joint + arg
            # flatten the argument dimensions in preparation for dynamic_partition
            joint = tf.reshape(joint, (-1, np.prod(tensor_dim_sizes)))
            joint = tf.transpose(joint, [1, 0])  # |tensor| x batch_size
            flat_tensor = tf.reshape(tensor, [-1])
            if error_symbol is not None:
                to_logsumexp = tf.dynamic_partition(joint, flat_tensor, output_size + 1)
                # drop the partition corresponding to the error symbol
                del to_logsumexp[error_symbol]
            else:
                to_logsumexp = tf.dynamic_partition(joint, flat_tensor, output_size)
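            # Each entry of flat_tensor is the output symbol assigned to one joint
            # configuration of the arguments, so dynamic_partition groups the rows
            # of `joint` by output symbol; the logsumexp below then accumulates,
            # in log space, the total joint weight mapped to each output value.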
            # tf.pack was renamed to tf.stack in TensorFlow 1.0
            result = tf.pack(
                [logsumexp(x, reduction_indices=0) for x in to_logsumexp]
            )
            result = tf.transpose(result, [1, 0])
            if not is_batched:
                result = tf.squeeze(result)
            return result
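
# To make the marginalization above concrete, the following is a small NumPy-only
# sketch of the same pattern for one hypothetical factor (addition mod 3). It
# assumes the arguments are log-space marginals of shape [batch_size, domain] and
# that the factor tensor is an integer table mapping each joint assignment of the
# arguments to an output symbol; both are inferences from the code above, not
# part of the original module.

def _np_logsumexp(x, axis=0):
    m = x.max(axis=axis, keepdims=True)
    return (m + np.log(np.exp(x - m).sum(axis=axis, keepdims=True))).squeeze(axis)


def np_apply_factor_demo():
    # Factor table for addition mod 3: entry [a, b] is the output symbol (a + b) % 3.
    table = (np.arange(3)[:, None] + np.arange(3)[None, :]) % 3   # shape (3, 3)
    output_size = 3

    # Two argument marginals in log space, batch size 1.
    log_p_a = np.log(np.array([[0.7, 0.2, 0.1]]))                 # (1, 3)
    log_p_b = np.log(np.array([[0.5, 0.3, 0.2]]))                 # (1, 3)

    # Broadcast-sum gives the log joint over (a, b): shape (1, 3, 3).
    log_joint = log_p_a[:, :, None] + log_p_b[:, None, :]

    # Group the flattened joint entries by the output symbol the table assigns
    # them and logsumexp each group -- the NumPy analogue of the
    # dynamic_partition / logsumexp step in apply_factor.
    flat_joint = log_joint.reshape(1, -1).T                       # (9, 1): |table| x batch
    flat_table = table.reshape(-1)                                # (9,)
    log_p_out = np.stack(
        [_np_logsumexp(flat_joint[flat_table == k]) for k in range(output_size)]
    ).T                                                           # (1, 3)
    return np.exp(log_p_out)                                      # ~ [[0.42, 0.33, 0.25]]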