c2w2c_textgen.py source code

python

Project: c2w2c  Author: milankinen
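import numpy as np

# EOW (the end-of-word token) and beamsearch() are defined elsewhere in the
# c2w2c project; they are not part of this excerpt.


def _sample_words(model, c, maxlen, V_C, K=20):
  """Decode candidate words for the context vector c with a character-level
  beam search (k=K), returning the words and their probabilities."""

  def predict(samples):
    # Repeat the word-level context once per beam hypothesis.
    context = np.array([c] * len(samples))
    # Column 0 is left as 0 (start of word); each previous character id is
    # shifted by +1 so that 0 never collides with a real character.
    prev_chars = np.zeros((len(samples), maxlen), dtype=np.int32)
    probs = np.zeros((len(samples), V_C.size), dtype=np.float32)
    for i, prev in enumerate(samples):
      for j, ch in enumerate(prev):
        prev_chars[i, j + 1] = ch + 1
    preds = model.predict_chars(context, prev_chars)
    # Keep only the model's output at index len(prev), i.e. the distribution
    # over the character that follows each prefix.
    for i, prev in enumerate(samples):
      np.copyto(probs[i], preds[i, len(prev)])
    return probs

  eow = V_C.get_index(EOW)
  best_chars, losses = beamsearch(predict, eow, k=K, maxsample=maxlen)
  # Map character ids back to strings, stopping at the end-of-word token.
  best_words = []
  for word_chars in best_chars:
    word = ""
    for ch in word_chars:
      if ch == eow:
        break
      word += V_C.get_token(ch)
    best_words.append(word)
  # 1 / exp(loss) == exp(-loss): turns each beam's loss back into a
  # probability, assuming the loss is a negative log-probability.
  probs = 1. / np.exp(np.array(losses))
  return best_words, probs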
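The function above relies on a `beamsearch` helper that this excerpt does not show. The sketch below is a hypothetical, minimal character-level beam search written only to match how `_sample_words` calls it: `beamsearch(predict, eow, k=K, maxsample=maxlen)` returning candidate sequences plus one loss per sequence. The loss convention (summed negative log-probability), the truncation at `maxsample` symbols, and the `1e-12` floor inside the log are assumptions, not the c2w2c project's actual implementation.

```python
import numpy as np


def beamsearch(predict, eos, k=1, maxsample=400):
    """Hypothetical beam search sketch (not the c2w2c implementation).

    predict(samples) -> array of shape (len(samples), vocab_size) holding the
    next-symbol probabilities after each prefix in `samples`.
    Returns up to k finished sequences (ending in `eos`, or truncated at
    `maxsample` symbols), best first, and their losses = -sum(log p).
    """
    live = [[]]                 # unfinished prefixes, starting from empty
    live_losses = np.zeros(1)   # negative log-probability of each prefix
    done, done_losses = [], []

    while live and len(done) < k:
        probs = predict(live)   # shape (len(live), vocab_size)
        # Loss of every one-symbol extension of every live prefix.
        cand = live_losses[:, None] - np.log(np.maximum(probs, 1e-12))
        flat = cand.ravel()
        vocab = probs.shape[1]
        # Keep only as many candidates as there are free beam slots.
        best = np.argsort(flat)[:k - len(done)]
        new_live, new_losses = [], []
        for idx in best:
            row, sym = divmod(int(idx), vocab)
            seq = live[row] + [sym]
            if sym == eos or len(seq) >= maxsample:
                done.append(seq)
                done_losses.append(flat[idx])
            else:
                new_live.append(seq)
                new_losses.append(flat[idx])
        live, live_losses = new_live, np.array(new_losses)

    order = np.argsort(done_losses)
    return [done[i] for i in order], [float(done_losses[i]) for i in order]


# Toy usage: a "model" that ignores the prefix; symbol 2 is end-of-word.
def toy_predict(samples):
    return np.tile(np.array([0.5, 0.2, 0.3]), (len(samples), 1))


seqs, losses = beamsearch(toy_predict, 2, k=2, maxsample=4)
print(seqs)                          # [[2], [0, 0, 0, 0]]
print(np.exp(-np.array(losses)))     # approximately [0.3, 0.0625]
```

Because live prefixes in this sketch are always shorter than `maxsample`, a call with `maxsample=maxlen` never asks `predict` to write past column `maxlen - 1` of `prev_chars`. Under the same loss convention, the `1. / np.exp(np.array(losses))` line in `_sample_words` recovers each word's probability.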