def _process_batch(self, batch):
    """Gather entries of `self._weights` selected by `batch`.

    The input is optionally cast to `tf.int32` (when `self._cast` is set)
    and optionally folded into the range of `self._num_buckets` (when
    `self._mod_inputs` is set) before being used as gather indices.

    Args:
      batch: Tensor of index values. Presumably integer ids, or values
        that can be cast to int32 -- confirm against the caller.

    Returns:
      The result of `tf.gather(self._weights, indices)`.
    """
    indices = batch
    if self._cast:
        indices = tf.cast(indices, tf.int32)
    if self._mod_inputs:
        # tf.abs has to be applied before tf.mod, because tf.mod gives
        # negative outputs when given negative inputs.
        magnitudes = tf.abs(indices)
        indices = tf.mod(magnitudes, self._num_buckets)
    return tf.gather(self._weights, indices)
评论列表
文章目录