import torch


def train_data(mini_batch, feature_batch, targets, word_attn_model, mix_softmax,
               optimizer, criterion, do_step=True, cuda=False, lstm=False):
    """Run one training step of the word-attention model plus the softmax head."""
    state_word = word_attn_model.init_hidden()
    optimizer.zero_grad()
    if cuda:
        if lstm:
            # An LSTM hidden state is an (h, c) pair; move both tensors to the GPU.
            # Rebuilding the pair as a tuple works whether init_hidden returned a list or a tuple.
            state_word = (state_word[0].cuda(), state_word[1].cuda())
        else:
            state_word = state_word.cuda()
        mini_batch[0] = mini_batch[0].cuda()
        mini_batch[1] = mini_batch[1].cuda()
        feature_batch = feature_batch.cuda()
    # Encode each of the two input sequences with the word-level attention model,
    # carrying the hidden state from the first pass into the second.
    s1, state_word, _ = word_attn_model(mini_batch[0].transpose(0, 1), state_word)
    s2, state_word, _ = word_attn_model(mini_batch[1].transpose(0, 1), state_word)
    s = torch.cat((s1, s2), 0)
    y_pred = mix_softmax(s, feature_batch)
    if cuda:
        y_pred = y_pred.cuda()
        targets = targets.cuda()
    loss = criterion(y_pred, targets)
    loss.backward()
    if do_step:
        optimizer.step()
    # Measure the gradient norm (the huge max_norm means no actual clipping happens).
    # clip_grad_norm_ replaces the deprecated clip_grad_norm, and the parameters are
    # gathered from the optimizer's param_groups instead of the non-standard _var_list.
    parameters = [p for group in optimizer.param_groups for p in group["params"]]
    grad_norm = torch.nn.utils.clip_grad_norm_(parameters, 1.0 * 1e20)
    return loss.item(), grad_norm
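

# A minimal sketch of how train_data might be driven from a training loop.
# The loader's (sent_a, sent_b), features, targets batch layout is an assumption,
# and the model/optimizer objects are whatever word-attention model, softmax head,
# optimizer, and criterion were constructed elsewhere in the original code.
def train_epoch(loader, word_attn_model, mix_softmax, optimizer, criterion, cuda=False):
    word_attn_model.train()
    mix_softmax.train()
    total_loss = 0.0
    for (sent_a, sent_b), features, targets in loader:
        loss, grad_norm = train_data([sent_a, sent_b], features, targets,
                                     word_attn_model, mix_softmax,
                                     optimizer, criterion,
                                     do_step=True, cuda=cuda)
        total_loss += loss
    return total_loss / max(len(loader), 1)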