train_lstm.py source code

Language: Python

Project: MemNN    Author: berlino
import torch
from torch.autograd import Variable

def train(epoch):
    for e_ in range(epoch):
        # decay the learning rate every 10 epochs
        if (e_ + 1) % 10 == 0:
            adjust_learning_rate(optimizer, e_)
        cnt = 0
        loss = Variable(torch.Tensor([0]))
        for i_q, i_w, i_e_p, i_a in zip(train_q, train_w, train_e_p, train_a):
            cnt += 1
            i_q = i_q.unsqueeze(0)  # add a batch dimension
            probs = model(i_q, i_w, i_e_p)
            i_a = Variable(i_a)
            curr_loss = loss_function(probs, i_a)
            # accumulate the per-sample loss, scaled by 1 / batch_size
            loss = torch.add(loss, torch.div(curr_loss, config.batch_size))

            # naive batching: the averaged loss is backpropagated and the
            # optimizer steps once every batch_size samples
            if cnt % config.batch_size == 0:
                print("Training loss", loss.data.sum())
                loss.backward()
                optimizer.step()
                loss = Variable(torch.Tensor([0]))
                model.zero_grad()
            if cnt % config.valid_every == 0:
                print("Accuracy:", eval())