def _filter_function(n_gpus):
    """Build a predicate that keeps an ``(x, y)`` pair only when the leading
    dimension of both ``x`` and ``y`` is evenly divisible by ``n_gpus``.

    Args:
        n_gpus: Divisor applied to the first dimension of each tensor.
            Assumed to be a positive integer (presumably the number of GPU
            replicas a batch is split across) -- TODO confirm with callers;
            a zero divisor is not meaningful here.

    Returns:
        A function ``f(x, y)`` returning a scalar boolean tensor that is True
        iff ``tf.shape(x)[0]`` and ``tf.shape(y)[0]`` are both multiples of
        ``n_gpus``. Presumably used as a ``tf.data.Dataset.filter`` predicate
        to drop batches that cannot be split evenly -- verify against caller.
    """

    def _leading_dim_divisible(t):
        # `%` on a tensor dispatches to floormod in both TF1 and TF2, whereas
        # `tf.mod` only exists in the TF1 top-level namespace.
        return tf.equal(tf.shape(t)[0] % n_gpus, 0)

    def f(x, y):
        return tf.logical_and(_leading_dim_divisible(x), _leading_dim_divisible(y))

    return f