seq2seq_helpers.py source code

python

Project: DeepDeepParser  Author: janmbuys
import tensorflow as tf


def gather_nd_states(inputs, inds, batch_size, input_size, state_size):
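  """Gathers one state vector per batch entry: output[b] = inputs[b, inds[b], :].

  Args:
    inputs: Tensor [batch_size, input_size, state_size].
    inds: int32 Tensor [batch_size], the position in inputs to gather for
      each batch entry.
    batch_size, input_size, state_size: sizes of the three dimensions of inputs.

  Returns:
    output: Tensor [batch_size, state_size]
  """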
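  # Row b of sparse_inds is [b, inds[b]].
  sparse_inds = tf.transpose(tf.stack(
      [tf.range(batch_size), inds]))
  # One-hot mask [batch_size, input_size]: 1 at (b, inds[b]), 0 elsewhere.
  dense_inds = tf.scatter_nd(sparse_inds,
      tf.ones(tf.stack([batch_size]), dtype=inputs.dtype),
      tf.stack([batch_size, input_size]))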

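  # Broadcast the mask against inputs and sum over the input_size axis, which
  # keeps only inputs[b, inds[b], :] for each batch entry b.
  output_sum = tf.reduce_sum(tf.reshape(dense_inds,
      [-1, input_size, 1, 1]) * tf.reshape(inputs,
        [-1, input_size, 1, state_size]), [1, 2])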
  output = tf.reshape(output_sum, [-1, state_size])
  return output
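To see why the mask-and-sum trick selects the right rows, it can be reproduced in NumPy. This is a minimal sketch, assuming NumPy is available; `gather_nd_states_np` is a hypothetical helper, not part of the DeepDeepParser project. It builds the same one-hot mask, applies the same reshapes and sum over axes 1 and 2, and compares the result with direct indexing `inputs[b, inds[b]]`.

```python
import numpy as np


def gather_nd_states_np(inputs, inds):
    """Hypothetical NumPy version of the one-hot mask-and-sum gather.

    inputs: array [batch_size, input_size, state_size]
    inds:   integer array [batch_size]
    Returns an array [batch_size, state_size] with output[b] = inputs[b, inds[b]].
    """
    batch_size, input_size, state_size = inputs.shape
    # One-hot mask: row b has a 1 at column inds[b], like dense_inds above.
    mask = np.zeros((batch_size, input_size), dtype=inputs.dtype)
    mask[np.arange(batch_size), inds] = 1
    # [B, N, 1, 1] * [B, N, 1, S] broadcasts to [B, N, 1, S];
    # summing over axes 1 and 2 leaves [B, S].
    output_sum = np.sum(
        mask.reshape(-1, input_size, 1, 1)
        * inputs.reshape(-1, input_size, 1, state_size),
        axis=(1, 2),
    )
    return output_sum.reshape(-1, state_size)


inputs = np.arange(2 * 3 * 4, dtype=np.float32).reshape(2, 3, 4)
inds = np.array([2, 0])
out = gather_nd_states_np(inputs, inds)
print(out.tolist())  # [[8.0, 9.0, 10.0, 11.0], [12.0, 13.0, 14.0, 15.0]]
# Direct advanced indexing picks exactly the same rows.
print(np.array_equal(out, inputs[np.arange(2), inds]))  # True
```

The comparison shows the function is a batched gather. In current TensorFlow the same result should be obtainable directly with `tf.gather(inputs, inds, batch_dims=1)` or `tf.gather_nd(inputs, tf.stack([tf.range(batch_size), inds], axis=1))`, which avoid materializing the `[batch_size, input_size, 1, state_size]` product.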