def _FloatyGatherGrad(op, grad):
    """Return the gradient of a gather-style op as sparse IndexedSlices.

    Args:
      op: The forward op. Its `inputs[0]` supplies the dense shape of the
        returned gradient. Its `inputs[1]` is flattened into the slice
        indices.
      grad: Incoming gradient for the op's output. It is reshaped so the
        leading dimension is inferred (-1) and the trailing dimensions match
        `op.inputs[0]`'s shape from dimension 1 onward.

    Returns:
      A two-element list:
        - an `ops.IndexedSlices` gradient for `op.inputs[0]`;
        - `None` for `op.inputs[1]`, since no gradient flows to the indices.
    """
    params = op.inputs[0]
    static_shape = params.get_shape()

    if not static_shape.is_fully_defined():
        # op.inputs[0] can be large, so colocate the shape calculation with it.
        with ops.colocate_with(params):
            dense_shape = array_ops.shape(params)
            # NOTE(review): `concat(0, values)` is the pre-TF-1.0 argument
            # order; TF >= 1.0 expects `concat(values, axis)`. Update if this
            # file targets a newer TensorFlow.
            values_shape = array_ops.concat(0, [[-1], dense_shape[1:]])
    else:
        # The static shape is fully known, so emit constants instead of a
        # runtime shape computation.
        dense_shape = constant_op.constant(static_shape.as_list())
        values_shape = [-1] + static_shape[1:].as_list()

    values = array_ops.reshape(grad, values_shape)
    # Flatten the indices and force int32.
    # NOTE(review): the cast suggests the forward op may receive non-int32
    # (possibly float) indices -- confirm against the forward op definition.
    flat_indices = math_ops.to_int32(array_ops.reshape(op.inputs[1], [-1]))
    return [ops.IndexedSlices(values, flat_indices, dense_shape), None]
# [web-page residue] Comment list
# [web-page residue] Table of contents