ptb_word_lm.py source file

python

Project: Language-Model-2016    Author: feizhihui
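import time

import numpy as np

import reader  # local PTB reader module (as in the TensorFlow PTB tutorial) providing ptb_iterator


def run_epoch(session, model, data, eval_op, verbose=False):
    """Runs the model on the given data for one epoch; returns (perplexity, accuracy)."""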
    epoch_size = ((len(data) // model.batch_size) - 1) // model.num_steps
    start_time = time.time()
    costs = 0.0
    accs = 0.0
    iters = 0
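    # Evaluate the initial LSTM state (model.initial_state, built with the zero_state op).
    # Result: a tuple with one (c, h) state pair per layer, each of shape [batch_size, size].
    lstm_state_value = session.run(model.initial_state)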
    for step, (x, y) in enumerate(reader.ptb_iterator(data, model.batch_size, model.num_steps)):
        feed_dict = {}
        feed_dict[model.input_data] = x
        feed_dict[model.targets] = y
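        # For each of the num_layers layers, feed the previous batch's final state
        # (the zero state on the first batch) as this batch's initial state.
        for i, (c, h) in enumerate(model.initial_state):
            # c and h each have shape [batch_size, size], e.g. [20, 200].
            feed_dict[c] = lstm_state_value[i].c
            feed_dict[h] = lstm_state_value[i].h
        # feed_dict now holds {x, y, c1, h1, c2, h2, ...}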
        cost, acc, lstm_state_value, _ = session.run([model.cost, model.accuracy, model.final_state, eval_op],
                                                     feed_dict)
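        # accs / iters divides by time steps, which assumes model.accuracy, like
        # model.cost, is summed over the num_steps time steps of the batch.
        accs += acc
        costs += cost  # cost of this batch, as returned by model.cost
        iters += model.num_steps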

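        # Print progress about ten times per epoch on large epochs (step % interval == 10);
        # max(1, ...) avoids a ZeroDivisionError when epoch_size < 10.
        if verbose and step % max(1, epoch_size // 10) == 10: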
            print("%.3f perplexity: %.3f speed: %.0f wps" %
                  (step * 1.0 / epoch_size, np.exp(costs / iters),
                   iters * model.batch_size / (time.time() - start_time)))
            print("Accuracy:", accs / iters)

    return np.exp(costs / iters), accs / iters
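
The loop depends on `reader.ptb_iterator`, which this file does not define. As a rough sketch, assuming it behaves like the `ptb_iterator` in the original TensorFlow PTB tutorial reader, it lays the token stream out as `batch_size` rows and slides a `num_steps` window along them, with `y` being `x` shifted one token ahead. `ptb_iterator_sketch` is a hypothetical name used only for illustration; its `epoch_size` matches the formula at the top of `run_epoch`.

```python
import numpy as np


def ptb_iterator_sketch(raw_data, batch_size, num_steps):
    """Hypothetical re-creation of reader.ptb_iterator: yield (x, y) batches.

    x and y both have shape [batch_size, num_steps]; y is x shifted by one token.
    """
    raw_data = np.asarray(raw_data, dtype=np.int32)
    batch_len = len(raw_data) // batch_size
    # Drop the tail that does not fill a full row, then lay the stream out as
    # batch_size contiguous rows, each an independent stretch of text.
    data = raw_data[:batch_size * batch_len].reshape(batch_size, batch_len)
    # Same value run_epoch computes: ((len(data) // batch_size) - 1) // num_steps
    epoch_size = (batch_len - 1) // num_steps
    if epoch_size == 0:
        raise ValueError("epoch_size == 0, decrease batch_size or num_steps")
    for i in range(epoch_size):
        x = data[:, i * num_steps:(i + 1) * num_steps]
        y = data[:, i * num_steps + 1:(i + 1) * num_steps + 1]
        yield x, y


batches = list(ptb_iterator_sketch(list(range(21)), batch_size=2, num_steps=3))
print(len(batches))             # 3
print(batches[0][0].tolist())   # [[0, 1, 2], [10, 11, 12]]
print(batches[0][1].tolist())   # [[1, 2, 3], [11, 12, 13]]
```

Because consecutive batches continue the same rows, the final LSTM state of one batch is a valid starting state for the next, which is why `run_epoch` feeds `lstm_state_value` back in on every step.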
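
`run_epoch` reports perplexity as `np.exp(costs / iters)`, where `iters` counts time steps. Assuming, as in the TensorFlow PTB tutorial, that `model.cost` is the per-token cross-entropy (natural log) summed over the `num_steps` time steps and averaged over the batch, `costs / iters` is the mean cross-entropy per token and its exponential is the perplexity. The helper below is a hypothetical stand-alone version of that calculation; as a sanity check, a model that predicts uniformly over a vocabulary of V words should score a perplexity of V.

```python
import numpy as np


def epoch_perplexity(batch_costs, num_steps):
    """Perplexity as run_epoch computes it: exp(total cost / total time steps).

    batch_costs: one model.cost value per batch (assumed summed over num_steps).
    """
    iters = num_steps * len(batch_costs)
    return float(np.exp(np.sum(batch_costs) / iters))


vocab_size = 10000
num_steps = 35
# A uniform model pays -ln(1/V) = ln(V) nats per token,
# so each batch costs num_steps * ln(V).
uniform_costs = [num_steps * np.log(vocab_size)] * 4
print(epoch_perplexity(uniform_costs, num_steps))  # approximately 10000.0
```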