utils.py source code

python

Project: emoatt  Author: epochx
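import tensorflow as tf
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops


def _extract_argmax_and_one_hot(one_hot_size,
                                output_projection=None):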
  """Get a loop_function that extracts the previous symbol and builds a one-hot vector for it.

  Args:
    one_hot_size: total size of the one-hot vector.
    output_projection: None or a pair (W, B). If provided, each fed-back previous
      output is first multiplied by W and then B is added.

  Returns:
    A loop function.
  """
  def loop_function(prev, _):
    # The second argument (the decoding step index) is unused.
    if output_projection is not None:
      prev = nn_ops.xw_plus_b(prev, output_projection[0], output_projection[1])
    prev_symbol = math_ops.argmax(prev, 1)
    # argmax has no gradient, so gradients do not propagate from the
    # fed-back one-hot input into prev.
    emb_prev = tf.one_hot(prev_symbol, one_hot_size)
    return emb_prev

  return loop_function
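The function above only builds the feedback step of greedy decoding: each decoder output is (optionally) projected to per-symbol scores, reduced to its argmax, and fed back to the next step as a one-hot vector. The sketch below mirrors that logic in NumPy so it can be run without TensorFlow. The names `extract_argmax_and_one_hot_np`, `run_greedy_feedback`, and the toy `shift_step` cell are hypothetical, and the decoder loop is an assumption modeled on how TensorFlow's legacy `rnn_decoder` calls `loop_function(prev, i)`; it is not code from the emoatt project.

```python
import numpy as np


def extract_argmax_and_one_hot_np(one_hot_size, output_projection=None):
    """NumPy sketch of the loop function above (hypothetical helper).

    output_projection is None or a pair (W, B) of arrays with shapes
    [depth, one_hot_size] and [one_hot_size].
    """
    def loop_function(prev, _step):
        # prev: [batch, depth] outputs of the previous decoder step.
        if output_projection is not None:
            w, b = output_projection
            # Same affine map as xw_plus_b: prev @ W + B.
            prev = prev @ w + b
        # Greedy choice: index of the highest score in each row.
        prev_symbol = np.argmax(prev, axis=1)
        # Select rows of the identity matrix: one one-hot row per example.
        return np.eye(one_hot_size, dtype=np.float32)[prev_symbol]

    return loop_function


def run_greedy_feedback(step_fn, first_input, state, loop_function, num_steps):
    """Hypothetical decoder loop: from step 1 on, the input is loop_function
    applied to the previous step's output."""
    inp, prev, outputs = first_input, None, []
    for i in range(num_steps):
        if prev is not None:
            inp = loop_function(prev, i)
        output, state = step_fn(inp, state)
        outputs.append(output)
        prev = output
    return outputs


# Usage: hidden size 3 projected to a 4-symbol vocabulary, batch of 2.
W = np.array([[1.0, 0.0, 0.0, 0.0],
              [0.0, 1.0, 0.0, 0.0],
              [0.0, 0.0, 1.0, 1.0]])
B = np.array([0.0, 0.0, 0.0, 0.5])
loop_fn = extract_argmax_and_one_hot_np(4, (W, B))
one_hot = loop_fn(np.array([[2.0, 1.0, 0.0], [0.0, 0.0, 1.0]]), 0)
# Projected scores are [2, 1, 0, 0.5] and [0, 0, 1, 1.5], so the rows are
# one-hot vectors for symbols 0 and 3.


# Toy "cell" (hypothetical): shifts the one-hot input one symbol to the right.
def shift_step(inp, state):
    return np.roll(inp, 1, axis=1), state


outputs = run_greedy_feedback(
    shift_step, np.eye(4, dtype=np.float32)[[0]], None,
    extract_argmax_and_one_hot_np(4), num_steps=4)
symbols = [int(np.argmax(o, axis=1)[0]) for o in outputs]  # [1, 2, 3, 0]
```

Indexing an identity matrix with the argmax indices is a compact way to build one-hot rows; it plays the same role as `tf.one_hot` in the original, and because the fed-back input is one-hot rather than an embedding lookup, the decoder's input size equals `one_hot_size`.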