def train(train_data, vocab, n_units=300, learning_rate_decay=0.97, seq_length=20, batch_size=20,
          epochs=20, learning_rate_decay_after=5):
    # Build the model: a GRU language model wrapped in a Classifier (softmax cross-entropy loss)
    model = L.Classifier(GRU(len(vocab), n_units))
    model.compute_accuracy = False  # we only need the loss, not accuracy
    # Set up the optimizer
    optimizer = optimizers.Adam()
    optimizer.setup(model)
    optimizer.add_hook(chainer.optimizer.GradientClipping(5))  # clip the gradient norm to avoid exploding gradients
    whole_len = train_data.shape[0]
    jump = whole_len // batch_size  # stride between the batch_size parallel read positions in the corpus
    epoch = 0
    start_at = time.time()
    cur_at = start_at
    loss = 0
    plt_loss = []
    print('going to train {} iterations'.format(jump * epochs))
    for seq in range(jump * epochs):
        # Each of the batch_size streams reads from its own offset into the corpus;
        # the teacher signal is simply the next token of each stream
        input_batch = np.array([train_data[(jump * j + seq) % whole_len]
                                for j in range(batch_size)])
        teach_batch = np.array([train_data[(jump * j + seq + 1) % whole_len]
                                for j in range(batch_size)])
        x = Variable(input_batch.astype(np.int32), volatile=False)
        teach = Variable(teach_batch.astype(np.int32), volatile=False)
        # Forward pass: accumulate the loss until the next parameter update
        loss += model(x, teach)
        # Truncated BPTT: backpropagate and update the parameters every seq_length steps
        if (seq + 1) % seq_length == 0:
            now = time.time()
            plt_loss.append(loss.data)
            print('{}/{}, train_loss = {}, time = {:.2f}'.format((seq + 1) // seq_length, jump,
                                                                 loss.data / seq_length, now - cur_at))
            # open('loss', 'w').write('{}\n'.format(loss.data / seq_length))
            cur_at = now
            model.cleargrads()
            loss.backward()
            loss.unchain_backward()  # cut the computational graph here to bound memory usage
            optimizer.update()
            loss = 0
        # checkpoint: periodically save a CPU copy of the model
        if (seq + 1) % 10000 == 0:
            pickle.dump(copy.deepcopy(model).to_cpu(), open('check_point', 'wb'))
        if (seq + 1) % jump == 0:
            epoch += 1
            if epoch >= learning_rate_decay_after:
                # Adam's lr is derived from alpha, so the decay step is left commented out
                # (decaying optimizer.alpha would be the Adam equivalent)
                # optimizer.lr *= learning_rate_decay
                print('decayed learning rate by a factor {} to {}'.format(learning_rate_decay, optimizer.lr))
        sys.stdout.flush()
    pickle.dump(copy.deepcopy(model).to_cpu(), open('rnnlm_model', 'wb'))
    plot_loss(plt_loss)
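

# A minimal usage sketch (an assumption, not part of the original code): it builds a
# character-level vocabulary from a text file and calls train(). The file name
# 'input.txt' and the hyperparameter values are illustrative only; GRU and plot_loss
# are assumed to be defined earlier in the article.
if __name__ == '__main__':
    with open('input.txt') as f:
        text = f.read()
    vocab = {}
    train_data = np.empty(len(text), dtype=np.int32)
    for i, ch in enumerate(text):
        if ch not in vocab:
            vocab[ch] = len(vocab)
        train_data[i] = vocab[ch]
    train(train_data, vocab, n_units=300, seq_length=20, batch_size=20, epochs=20)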