import tensorflow as tf


def make_batch_consistent(args, set_batch_size=None):
    """Tile rank-1 args so that every arg has a consistent batch size.

    Each args[i] should have shape [arg_dim] or [batch_size, arg_dim];
    any rank-1 arg is tiled to [batch_size, arg_dim].
    """
    if set_batch_size is None:
        # infer the batch_size from arg shapes; use a list comprehension so
        # len() works (in Python 3, filter() returns an iterator)
        batched_args = [x for x in args if x.get_shape().ndims > 1]
        if len(batched_args) == 0:
            batch_size = 1
            is_batched = False
        else:
            # TODO: tf.assert_equal() to check that all batch sizes are consistent?
            batch_size = tf.shape(batched_args[0])[0]
            is_batched = True
    else:
        batch_size = set_batch_size
        is_batched = True
    # tile any rank-1 args up to the consistent batch_size
    tmp_args = []
    for arg in args:
        arg_rank = arg.get_shape().ndims
        # inline check in place of the external assert_rank_1_or_2 helper
        assert arg_rank in (1, 2), "expected rank 1 or 2, got rank %d" % arg_rank
        if arg_rank == 1:
            tmp_args.append(tf.tile(tf.expand_dims(arg, 0), [batch_size, 1]))
        else:
            tmp_args.append(arg)
    return tmp_args, is_batched
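
# A minimal usage sketch, assuming TensorFlow 1.x graph mode (the function
# above relies on get_shape() and a graph-mode tf.shape); the tensors below
# are hypothetical, made up for illustration.
bias = tf.constant([0.1, 0.2, 0.3])                     # rank 1: [arg_dim]
features = tf.placeholder(tf.float32, shape=[None, 3])  # rank 2: [batch_size, arg_dim]

out_args, is_batched = make_batch_consistent([bias, features])
# out_args[0] is the bias tiled to [batch_size, 3]; is_batched is True because
# the batch size was inferred from the rank-2 `features` tensor.

only_bias, was_batched = make_batch_consistent([bias])
# only_bias[0] has shape [1, 3] and was_batched is False: with no rank-2 args,
# the batch size defaults to 1.

with tf.Session() as sess:
    tiled = sess.run(out_args[0],
                     feed_dict={features: [[1., 2., 3.], [4., 5., 6.]]})
    print(tiled.shape)  # (2, 3)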