def floaty_scatter_update(ref, indices, updates, **kwargs):
    """Run ``tf.scatter_update`` with ``indices`` cast to int32 first.

    ``tf.scatter_update`` only accepts integer (int32/int64) indices. This
    wrapper lets callers pass indices of another numeric dtype (presumably
    float, given the name) by casting them to ``tf.int32`` beforehand.
    A float-to-int cast truncates the fractional part; it does not round.

    Args:
        ref: The variable to update; passed through unchanged.
        indices: Index tensor of any dtype ``tf.cast`` can convert to int32.
        updates: Values to write at ``indices``; passed through unchanged.
        **kwargs: Forwarded verbatim to ``tf.scatter_update``
            (e.g. ``use_locking``, ``name``).

    Returns:
        Whatever ``tf.scatter_update`` returns for these arguments.
    """
    # tf.to_int32 is deprecated (and absent from the TF2 namespace); tf.cast
    # is its documented replacement and yields the same int32 tensor.
    int_indices = tf.cast(indices, tf.int32)
    return tf.scatter_update(ref, int_indices, updates, **kwargs)