def train_loop():
    """Consume work items from ``data_q`` and report results on ``res_q``.

    Each item taken from ``data_q`` is handled as follows:

    * ``'end'``   -- echoed to ``res_q``, then the loop exits.
    * ``'train'`` -- echoed to ``res_q``; ``model.train`` is set to True.
    * ``'val'``   -- echoed to ``res_q``; the model and the optimizer are
      saved to ``args.out`` and ``args.outstate``; ``model.train`` is set
      to False.
    * anything else -- indexed as a pair: element 0 and element 1 are each
      wrapped in a ``chainer.Variable`` and passed to ``optimizer.update``
      (when ``model.train`` is true) or to ``model`` directly (otherwise).
      Afterwards ``(loss, accuracy)`` is put on ``res_q`` as plain floats.

    The first training update also writes the computational graph of
    ``model.loss`` to ``graph.dot`` and reports that on stderr.
    """
    # Trainer
    graph_dumped = False
    while True:
        # Poll until an item is available.
        # NOTE(review): presumably polling (rather than a bare blocking
        # get) is deliberate, e.g. to keep the loop interruptible -- confirm.
        while data_q.empty():
            time.sleep(0.1)
        item = data_q.get()

        # Control messages are echoed to res_q before being acted upon.
        if item == 'end':  # quit
            res_q.put('end')
            break
        if item == 'train':  # restart training
            res_q.put('train')
            model.train = True
            continue
        if item == 'val':  # start validation
            res_q.put('val')
            serializers.save_npz(args.out, model)
            serializers.save_npz(args.outstate, optimizer)
            model.train = False
            continue

        # Anything else is a data item: item[0] and item[1] become x and t.
        if model.train:
            volatile = 'off'
        else:
            volatile = 'on'
        x = chainer.Variable(xp.asarray(item[0]), volatile=volatile)
        t = chainer.Variable(xp.asarray(item[1]), volatile=volatile)

        if not model.train:
            model(x, t)
        else:
            optimizer.update(model, x, t)
            if not graph_dumped:
                # Dump the graph only once, after the first update.
                with open('graph.dot', 'w') as dot_file:
                    dot_text = computational_graph.build_computational_graph(
                        (model.loss,)).dump()
                    dot_file.write(dot_text)
                print('generated graph', file=sys.stderr)
                graph_dumped = True

        loss_value = float(model.loss.data)
        accuracy_value = float(model.accuracy.data)
        res_q.put((loss_value, accuracy_value))
        # Drop the references to the Variables before waiting for more work.
        del x, t
# Invoke threads
# Comment list
# Article table of contents