Source code of common_attention.py

python

Project: tensor2tensor  Author: tensorflow
import tensorflow as tf  # required import (provided at the top of common_attention.py)


def conv_elems_1d(x, factor, out_depth=None):
  """Decrease the length and change the dimensionality.

  Merge/restore/compress factors positions of dim depth of the input into
  a single position of dim out_depth.
  This is basically just a strided convolution without overlapp
  between each strides.
  The original length has to be divided by factor.

  Args:
    x (tf.Tensor): shape [batch_size, length, depth]
    factor (int): Length compression factor.
    out_depth (int): Output depth

  Returns:
    tf.Tensor: shape [batch_size, length//factor, out_depth]
  """
  out_depth = out_depth or x.get_shape().as_list()[-1]
  # with tf.control_dependencies(  # Dynamic assertion
  #     [tf.assert_equal(tf.shape(x)[1] % factor, 0)]):
  x = tf.expand_dims(x, 1)  # [batch_size, 1, length, depth]
  x = tf.layers.conv2d(
      inputs=x,
      filters=out_depth,
      kernel_size=(1, factor),
      strides=(1, factor),
      padding="valid",
      data_format="channels_last",
  )  # [batch_size, 1, length//factor, out_depth]
  x = tf.squeeze(x, 1)  # [batch_size, length//factor, out_depth]
  return x
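
Below is a minimal usage sketch, not part of the original file, showing how conv_elems_1d compresses a length-16 sequence by a factor of 4 while projecting the depth from 64 to 128. The concrete shapes, the random input, and the TF 1.x placeholder/session style are assumptions for illustration; the length is chosen so it is divisible by the factor.

import numpy as np
import tensorflow as tf

# Hypothetical input: a batch of sequences with length 16 and depth 64.
x = tf.placeholder(tf.float32, shape=[None, 16, 64])
# Compress the length by a factor of 4 and project depth 64 -> 128.
y = conv_elems_1d(x, factor=4, out_depth=128)  # [batch_size, 4, 128]

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  out = sess.run(y, feed_dict={x: np.random.rand(2, 16, 64)})
  print(out.shape)  # (2, 4, 128)

Because the padding is "valid" and the stride equals the kernel size, each output position covers exactly factor input positions with no overlap, which is why the input length has to divide evenly by factor.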