def select_present(x, presence, batch_size=1, name='select_present'):
    """Reorder ``x`` so that, within each batch row, present entries come first.

    Each entry of ``x`` (along the leading dimensions covered by ``presence``)
    is routed by ``tf.dynamic_partition`` into one of ``2 * batch_size``
    partitions. For batch row ``i``, present entries go to partition ``2 * i``
    and absent ones to ``2 * i + 1``. Concatenating the partitions in order and
    reshaping back to ``tf.shape(x)`` puts present entries in front of absent
    ones within every row, preserving their relative order.

    Args:
        x: Tensor to reorder; its leading dimensions must match ``presence``
            (a ``tf.dynamic_partition`` requirement). NOTE(review): presumably
            shaped ``(batch, k, ...)`` with ``presence`` shaped ``(batch, k)``.
            Confirm against callers.
        presence: Boolean or 0/1 mask. It is cast to int32 and inverted, so
            values other than exactly 0/1 would produce invalid partition ids.
        batch_size: Python int used only when the static size of ``x``'s first
            dimension is unknown. It must equal the runtime batch size, because
            ``tf.dynamic_partition`` needs a static number of partitions.
        name: Name of the variable scope wrapping the created ops.

    Returns:
        Tensor with the same shape as ``x``, with present entries moved to the
        front of each batch row.
    """
    with tf.variable_scope(name):
        # Invert the mask: present -> 0, absent -> 1.
        presence = 1 - tf.to_int32(presence)

        # Prefer the static batch size when it is known. In TF1 the shape entry
        # is a tf.Dimension, and `Dimension != None` evaluates to None (falsy),
        # so a plain `bs != None` test never succeeds. Unwrap `.value` first;
        # TF2-style shapes already hold a plain int or None.
        bs = x.get_shape()[0]
        bs = getattr(bs, 'value', bs)
        if bs is not None:
            batch_size = int(bs)

        # Offset row i by 2*i so that each row owns its own pair of partitions:
        # 2*i for present entries and 2*i + 1 for absent ones.
        num_partitions = 2 * batch_size
        offsets = tf.range(0, num_partitions, 2)
        offsets.set_shape(tf.TensorShape([batch_size]))
        # NOTE(review): broadcast_against is defined elsewhere. Presumably it
        # expands/tiles `offsets` to `presence`'s shape along the batch axis.
        # Confirm.
        offsets = broadcast_against(offsets, presence)
        partition_ids = presence + offsets

        selected = tf.dynamic_partition(x, partition_ids, num_partitions)
        selected = tf.concat(axis=0, values=selected)
        selected = tf.reshape(selected, tf.shape(x))
        # Reshaping to a dynamic shape tensor can drop static shape info. The
        # result has exactly x's shape, so restore it for downstream ops.
        selected.set_shape(x.get_shape())

    return selected
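# NOTE: residue from the web page this code was copied from; not code.
# ("评论列表" = "comment list", "文章目录" = "table of contents")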