def batch_gather(reference, indices):
    """Gather one entry along axis 1 for every batch element.

    The NumPy equivalent is ``reference[np.arange(batch_size), indices]``.

    # Arguments
        reference: tensor with ndim >= 2 of shape
            (batch_size, dim1, dim2, ..., dimN).
        indices: 1-D integer tensor (int32 or int64) of shape (batch_size,)
            satisfying 0 <= i < dim1 for each element i.

    # Returns
        A tensor of shape (batch_size, dim2, ..., dimN) equal to
        ``reference[np.arange(batch_size), indices]``.
    """
    indices = tf.convert_to_tensor(indices)
    batch_size = tf.shape(reference)[0]
    # tf.range always yields int32; cast it to the indices' dtype so that
    # int64 indices can be stacked with it as well.
    batch_range = tf.cast(tf.range(batch_size), indices.dtype)
    # Build a (batch_size, 2) coordinate tensor whose row b is
    # [b, indices[b]], which is the form tf.gather_nd expects.
    # (tf.pack was renamed to tf.stack in TensorFlow 1.0.)
    coords = tf.stack([batch_range, indices], axis=1)
    return tf.gather_nd(reference, coords)
评论列表
文章目录