def apply_factor(tensor, *args, **kwargs):
    """Push the joint weight of ``args`` through an integer lookup table.

    The per-argument tensors in ``args`` are multiplied together with
    broadcasting. This forms an outer product over their non-batch
    dimensions. The resulting weight of every argument combination is then
    summed into the output bin named by the corresponding table entry, using
    ``tf.unsorted_segment_sum``.

    NOTE(review): the details below are inferred from how the values are used
    here; confirm them against callers and the helpers called below.
      * ``tensor`` is unpacked as a ``(table, output_size, error_symbol)``
        triple. The entries of ``table`` are used as segment ids, i.e.
        output-bin indices.
      * When ``error_symbol`` is not None, one extra bin (index
        ``output_size``) is allocated and then dropped. It presumably
        collects the error symbol's mass.
      * ``slice_out_int_literals`` presumably removes int-literal arguments
        and slices the table accordingly. ``make_batch_consistent``
        presumably returns ``batch x dim`` tensors plus a batched flag.

    Args:
        tensor: ``(table, output_size, error_symbol)`` triple.
        *args: arguments indexing the table, one per table dimension.
        scope: optional keyword argument. It is the name scope the ops are
            created under and defaults to ``""``. Any other keyword arguments
            are ignored.

    Returns:
        A ``batch x output_size`` tensor, passed through ``tf.squeeze`` when
        the inputs are reported as unbatched. If there are no arguments, or
        all arguments were sliced away, returns whatever ``one_hot`` returns
        for the table.
    """
    scope = kwargs.pop("scope", "")
    with tf.name_scope(scope):
        if not args:
            table, output_size, _ = tensor
            return one_hot(table, output_size, scope=scope)

        tensor, args = slice_out_int_literals(tensor, list(args))
        args, is_batched = make_batch_consistent(args)
        table, output_size, error_symbol = tensor

        # Every argument was an int literal: the sliced table has rank 0.
        table_dim_sizes = [dim.value for dim in table.get_shape()]
        if not table_dim_sizes:
            return one_hot(table, output_size, scope=scope)

        # Each arg is expected to be batch x arg_dim. Reshape arg i with
        # singleton axes so that its dim sits at axis i + 1. The product below
        # then broadcasts to batch x dim_0 x ... x dim_{k-1}.
        # len(args) (not the original argument count) is used because
        # counting sliced-out literals would only add trailing size-1 axes.
        rank = len(args)
        expanded = []
        for i, arg in enumerate(args):
            for j in range(rank):
                if j != i:
                    arg = tf.expand_dims(arg, j + 1)
            expanded.append(arg)

        # Joint weight of every argument combination, before the table is applied.
        joint = 1
        for arg in expanded:
            joint = joint * arg

        # Flatten to batch x |table|, then transpose so rows line up with the
        # flattened table entries used as segment ids.
        joint = tf.reshape(joint, (-1, np.prod(table_dim_sizes)))
        joint = tf.transpose(joint, [1, 0])  # |table| x batch_size

        num_bins = output_size if error_symbol is None else output_size + 1
        result = tf.unsorted_segment_sum(joint, tf.reshape(table, [-1]), num_bins)
        if error_symbol is not None:
            # Assume the error bin is the last bin; drop it.
            result = result[:output_size, :]

        result = tf.transpose(result, [1, 0])  # batch_size x output_size
        if not is_batched:
            # NOTE(review): squeezing without an axis also removes the output
            # axis when output_size == 1. Consider axis=0 if the unbatched
            # batch dimension is always 1 (confirm with make_batch_consistent).
            result = tf.squeeze(result)
        return result
评论列表
文章目录