inference.py source code

python

Project: deep_srl  Author: luheng
import numpy


def viterbi_decode(score, transition_params):
  """ Adapted from Tensorflow implementation.
  Decode the highest scoring sequence of tags outside of TensorFlow.
  This should only be used at test time.
  Args:
    score: A [seq_len, num_tags] matrix of unary potentials.
    transition_params: A [num_tags, num_tags] matrix of binary potentials,
        indexed as [previous_tag, current_tag].
  Returns:
    viterbi: A [seq_len] list of integers containing the highest scoring tag
        indices.
    viterbi_score: A float containing the score for the Viterbi sequence.
  """
  # trellis[t, j]: best score of any path that ends in tag j at step t.
  trellis = numpy.zeros_like(score)
  # backpointers[t, j]: tag at step t - 1 on that best path.
  backpointers = numpy.zeros_like(score, dtype=numpy.int32)
  trellis[0] = score[0]
  for t in range(1, score.shape[0]):
    # v[i, j] = best score ending in tag i at t - 1, plus transition i -> j.
    v = numpy.expand_dims(trellis[t - 1], 1) + transition_params
    trellis[t] = score[t] + numpy.max(v, 0)
    backpointers[t] = numpy.argmax(v, 0)
  # Start from the best final tag and follow the backpointers to step 0.
  viterbi = [numpy.argmax(trellis[-1])]
  for bp in reversed(backpointers[1:]):
    viterbi.append(bp[viterbi[-1]])
  viterbi.reverse()
  viterbi_score = numpy.max(trellis[-1])
  return viterbi, viterbi_score
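
To make it concrete what the decoder maximizes, here is a minimal brute-force sketch that scores every possible tag sequence and keeps the best one. The helper names sequence_score and brute_force_decode, and the toy potentials, are hypothetical and not part of deep_srl. The sketch assumes transition_params is indexed as [previous_tag, current_tag], which is what the broadcasting in viterbi_decode implies. It indexes with score[t][tag], so it should accept plain nested lists as well as NumPy arrays, and it is only practical for tiny inputs because it enumerates num_tags ** seq_len paths.

```python
import itertools


def sequence_score(tags, score, transition_params):
  """Sum the unary and transition potentials along one tag path."""
  total = score[0][tags[0]]
  for t in range(1, len(tags)):
    # Assumed convention: transition_params[previous_tag][current_tag].
    total += transition_params[tags[t - 1]][tags[t]] + score[t][tags[t]]
  return total


def brute_force_decode(score, transition_params):
  """Enumerate every tag sequence and return the best one with its score."""
  seq_len, num_tags = len(score), len(score[0])
  best = max(
      itertools.product(range(num_tags), repeat=seq_len),
      key=lambda tags: sequence_score(tags, score, transition_params))
  return list(best), sequence_score(best, score, transition_params)


# Toy potentials (illustrative only): 3 time steps, 2 tags.
score = [[1.0, 0.0], [0.0, 1.0], [1.0, 0.0]]
# Staying on the same tag is rewarded, switching is penalized.
transition_params = [[0.5, -1.0], [-1.0, 0.5]]

best_tags, best_score = brute_force_decode(score, transition_params)
print(best_tags, best_score)  # [0, 0, 0] 3.0
```

On this toy input the path [0, 0, 0] wins: the transition reward for staying on tag 0 outweighs the unary preference for tag 1 at the middle step. Passing the same potentials to viterbi_decode as NumPy arrays should give the same path and score, which makes the brute-force version a handy sanity check on small cases.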