nn1.py source code

python

Project: YellowFin_Pytorch · Author: JianGoForIt · File source
import torch


def train_data(mini_batch, feature_batch, targets, word_attn_model, mix_softmax,
               optimizer, criterion, do_step=True, cuda=False, lstm=False):
    """Run one training step: encode two word sequences with the shared
    word-attention model, classify them together with the extra features,
    backpropagate the loss, and return the loss value and gradient norm."""
    # Fresh hidden state for the word-level attention RNN.
    state_word = word_attn_model.init_hidden()
    optimizer.zero_grad()

    if cuda:
        # An LSTM hidden state is an (h, c) pair; a GRU's is a single tensor.
        if lstm:
            state_word[0] = state_word[0].cuda()
            state_word[1] = state_word[1].cuda()
        else:
            state_word = state_word.cuda()
        mini_batch[0] = mini_batch[0].cuda()
        mini_batch[1] = mini_batch[1].cuda()
        feature_batch = feature_batch.cuda()

    # Encode both sequences; inputs are transposed to the (seq_len, batch)
    # layout the recurrent model expects, and the hidden state is reused.
    s1, state_word, _ = word_attn_model(mini_batch[0].transpose(0, 1), state_word)
    s2, state_word, _ = word_attn_model(mini_batch[1].transpose(0, 1), state_word)
    s = torch.cat((s1, s2), 0)

    # Combine the sequence encodings with the hand-crafted feature vector.
    y_pred = mix_softmax(s, feature_batch)
    if cuda:
        y_pred = y_pred.cuda()
        targets = targets.cuda()

    loss = criterion(y_pred, targets)
    loss.backward()

    if do_step:
        optimizer.step()
    # With max_norm = 1e20 this call effectively performs no clipping; it is
    # used only to read off the gradient norm of the YellowFin optimizer's
    # variable list.
    grad_norm = torch.nn.utils.clip_grad_norm(optimizer._var_list, 1.0 * 1e20)

    # loss.data[0] is the pre-0.4 PyTorch way of extracting a scalar loss.
    return loss.data[0], grad_norm
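
For reference, here is a minimal sketch of how train_data might be driven from a training loop. The run_epoch helper, the batch layout (sentence pair, feature tensor, labels), and the Variable wrapping are illustrative assumptions, not part of the original file; the word-attention model, mix_softmax classifier, YellowFin optimizer, and criterion would be constructed elsewhere in nn1.py.

from torch.autograd import Variable  # pre-0.4 PyTorch API, matching the snippet above

def run_epoch(batches, word_attn_model, mix_softmax, optimizer, criterion, cuda=False):
    # `batches` is assumed to yield (sent_a, sent_b, feats, labels) tensor tuples,
    # where sent_a / sent_b are padded word-index matrices of shape (batch, seq_len).
    total_loss = 0.0
    n_batches = 0
    for sent_a, sent_b, feats, labels in batches:
        mini_batch = [Variable(sent_a), Variable(sent_b)]
        feature_batch = Variable(feats)
        targets = Variable(labels)
        loss, grad_norm = train_data(mini_batch, feature_batch, targets,
                                     word_attn_model, mix_softmax,
                                     optimizer, criterion,
                                     do_step=True, cuda=cuda, lstm=False)
        total_loss += loss
        n_batches += 1
    return total_loss / max(n_batches, 1)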