def _filter_negative_samples(labels, tensors):
    """Keep only the samples whose label is non-negative.

    Params:
    -----
    labels: of shape (N,)
    tensors: a list of tensors, each of shape (N, ...); the first axis is
        the sample index and must have the same length as ``labels``.

    Returns:
    -----
    filtered: a list with one tensor per entry of ``tensors`` (same order),
        each gathered along axis 0 at the positions where ``labels >= 0``.
    """
    # tf.where on a rank-1 boolean mask yields indices of shape (K, 1);
    # flatten to a (K,) index vector for tf.gather.
    # NOTE(review): this relies on labels being rank 1 as documented; a
    # rank-2 labels tensor would produce interleaved (row, col) indices here.
    keeps = tf.where(tf.greater_equal(labels, 0))
    keeps = tf.reshape(keeps, [-1])

    num_samples = tf.shape(labels)[0]
    filtered = []
    for t in tensors:
        # In graph mode an assert op that nothing depends on is never run,
        # so make the gather depend on the check. In eager mode the assert
        # raises immediately and control_dependencies is a no-op.
        check = tf.assert_equal(tf.shape(t)[0], num_samples)
        with tf.control_dependencies([check]):
            filtered.append(tf.gather(t, keeps))
    return filtered
# (Web-page residue, not code: "评论列表" = "Comment list", "文章目录" = "Table of contents".)