train_char.py source code


Project: Tree-LSTM-LM | Author: vgene | Language: Python
import numpy as np
import torch
import tqdm
from torch.autograd import Variable  # legacy (pre-0.4) PyTorch API, matching the rest of the file


def run_epoch(model, reader, criterion, is_train=False, use_cuda=False, lr=0.01):
    """Run one epoch over the character data and return the final perplexity.

        reader: data provider
        criterion: loss function (expects logits of shape [N, vocab_size] and targets of shape [N])
        lr: unused here; the learning rate is taken from model.optimizer
    """
    # Put the model in the right mode so layers such as dropout behave correctly.
    if is_train:
        model.train()
    else:
        model.eval()

    epoch_size = ((reader.file_length // model.batch_size) - 1) // model.seq_length  # optimisation steps per epoch

    hidden = model.init_hidden()

    costs = 0.0  # running sum of mean per-token loss, scaled by seq_length
    for steps, (inputs, targets) in tqdm.tqdm(enumerate(reader.iterator_char(model.batch_size, model.seq_length)), total=epoch_size):
        model.optimizer.zero_grad()
        # Reshape from (batch, seq) to (seq, batch): the model expects time-major input.
        inputs = Variable(torch.from_numpy(inputs.astype(np.int64)).transpose(0, 1).contiguous())
        targets = Variable(torch.from_numpy(targets.astype(np.int64)).transpose(0, 1).contiguous())
        if use_cuda:
            inputs = inputs.cuda()
            targets = targets.cuda()
        targets = targets.view(-1)  # flatten to [seq_length * batch_size] for the loss
        # Detach the hidden state from the previous step's graph (truncated BPTT).
        hidden = repackage_hidden(hidden, use_cuda=use_cuda)
        outputs, hidden = model(inputs, hidden)

        loss = criterion(outputs.view(-1, model.vocab_size), targets)
        # criterion returns the mean loss per token; .data[0] is the legacy
        # scalar access (loss.item() in PyTorch >= 0.4).
        costs += loss.data[0] * model.seq_length
        # Running perplexity: exp of the average per-token loss so far.
        perplexity = np.exp(costs / ((steps + 1) * model.seq_length))

        if is_train:
            loss.backward()  # backprop through the truncated sequence
            model.optimizer.step()

    return perplexity
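
run_epoch calls a repackage_hidden helper that is defined elsewhere in the project. For reference, a minimal sketch of the conventional implementation, assuming the hidden state is either a single Variable or an LSTM (h, c) tuple; the use_cuda keyword mirrors the call above:

def repackage_hidden(h, use_cuda=False):
    # Sketch (assumption, not the project's exact code): detach the hidden
    # state from the old graph so backprop stops here (truncated BPTT).
    if isinstance(h, tuple):
        # LSTM hidden state: repackage each tensor in the (h, c) pair.
        return tuple(repackage_hidden(v, use_cuda=use_cuda) for v in h)
    h = Variable(h.data)  # fresh leaf Variable with no history
    return h.cuda() if use_cuda else h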
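For context, one way to drive run_epoch over several epochs. CharLM, train_reader, and valid_reader are stand-in names not in the original file; only the run_epoch signature and the attributes it uses (batch_size, seq_length, optimizer, init_hidden, iterator_char, file_length) come from the code above:

# Hypothetical driver loop; CharLM and the readers are stand-in names.
model = CharLM(batch_size=32, seq_length=50)  # assumed constructor
criterion = torch.nn.CrossEntropyLoss()
for epoch in range(10):
    train_ppl = run_epoch(model, train_reader, criterion, is_train=True, use_cuda=True)
    valid_ppl = run_epoch(model, valid_reader, criterion, is_train=False, use_cuda=True)
    print("Epoch {}: train ppl {:.2f}, valid ppl {:.2f}".format(epoch + 1, train_ppl, valid_ppl))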