train_rnnlm.py source code


Project: TOHO_AI    Author: re53min
import copy
import pickle
import sys
import time

import numpy as np
import chainer
import chainer.links as L
from chainer import Variable, optimizers

# GRU and plot_loss are helpers defined elsewhere in the TOHO_AI project


def train(train_data, vocab, n_units=300, learning_rate_decay=0.97, seq_length=20, batch_size=20,
          epochs=20, learning_rate_decay_after=5):
    # Build the model: a GRU language model wrapped in a softmax cross-entropy classifier
    model = L.Classifier(GRU(len(vocab), n_units))
    model.compute_accuracy = False

    # Set up the optimizer (Adam) with gradient clipping
    optimizer = optimizers.Adam()
    optimizer.setup(model)
    optimizer.add_hook(chainer.optimizer.GradientClipping(5))  # clip gradients to avoid explosion

    whole_len = train_data.shape[0]
    jump = whole_len // batch_size  # integer division so it can be used as a range bound below
    epoch = 0
    start_at = time.time()
    cur_at = start_at
    loss = 0
    plt_loss = []

    print('going to train {} iterations'.format(jump * epochs))
    for seq in range(jump * epochs):

        input_batch = np.array([train_data[(jump * j + seq) % whole_len]
                                for j in range(batch_size)])
        teach_batch = np.array([train_data[(jump * j + seq + 1) % whole_len]
                                for j in range(batch_size)])
        x = Variable(input_batch.astype(np.int32))
        teach = Variable(teach_batch.astype(np.int32))

        # Forward pass: accumulate the softmax cross-entropy loss
        loss += model(x, teach)

        # Truncated BPTT: backpropagate and update every seq_length steps
        if (seq + 1) % seq_length == 0:
            now = time.time()
            plt_loss.append(loss.data)
            print('{}/{}, train_loss = {}, time = {:.2f}'.format((seq + 1) // seq_length, jump,
                                                                 loss.data / seq_length, now - cur_at))
            # open('loss', 'w').write('{}\n'.format(loss.data / seq_length))
            cur_at = now

            model.cleargrads()
            loss.backward()
            loss.unchain_backward()
            optimizer.update()
            loss = 0

        # Checkpoint: periodically save a CPU copy of the model
        if (seq + 1) % 10000 == 0:
            with open('check_point', 'wb') as f:
                pickle.dump(copy.deepcopy(model).to_cpu(), f)

        if (seq + 1) % jump == 0:
            epoch += 1
            if epoch >= learning_rate_decay_after:
                # Learning-rate decay (left disabled here; Adam derives its step size from alpha)
                # optimizer.lr *= learning_rate_decay
                print('decayed learning rate by a factor {} to {}'.format(learning_rate_decay, optimizer.lr))

        sys.stdout.flush()

    # Save the final model (on CPU) and plot the loss curve
    with open('rnnlm_model', 'wb') as f:
        pickle.dump(copy.deepcopy(model).to_cpu(), f)
    plot_loss(plt_loss)
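
A minimal usage sketch, assuming a character-level corpus; the preprocessing below and the file name corpus.txt are illustrative assumptions, and GRU / plot_loss come from the TOHO_AI project itself.

import numpy as np

# Assumed preprocessing: map each character of a text corpus to an integer ID
with open('corpus.txt', encoding='utf-8') as f:
    text = f.read()
vocab = {ch: i for i, ch in enumerate(sorted(set(text)))}
train_data = np.array([vocab[ch] for ch in text], dtype=np.int32)

# Train with the defaults above (300 hidden units, batches of 20, 20 epochs)
train(train_data, vocab)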