def train_network(num_epochs, num_steps, state_size=4):
    """Train the graph for ``num_epochs`` epochs, evaluating on the test set after each.

    Relies on module-level objects defined elsewhere in the file: ``sess``,
    ``tf``, ``train_step``, ``loss``, ``accuracy``, ``train_summary``,
    ``test_summary``, the ``x``/``y`` placeholders, ``train_writer``,
    ``run_options``/``run_metadata``, ``timeline``, ``tqdm``, ``genEpochs``,
    ``getTestData`` and the size constants ``num_batches``, ``batch_size``,
    ``num_classes`` and ``copy_len``.

    Args:
        num_epochs: number of epochs requested from ``genEpochs``.
        num_steps: passed through to ``genEpochs``.
        state_size: kept for interface compatibility. No recurrent state is
            fed through ``feed_dict`` here, so it is currently unused.
            NOTE(review): confirm the graph does not expect an initial state.

    Returns:
        A list with the mean training loss of each epoch. The original
        returned ``None``, so existing callers are unaffected.

    Side effects:
        Writes train/test summaries to ``train_writer``, prints per-epoch
        metrics, and writes a Chrome trace of the most recent ``sess.run``
        that received ``run_metadata`` to ``timeline_add.json``.
    """
    # NOTE(review): tf.initialize_all_variables is deprecated in later TF 1.x
    # releases in favour of tf.global_variables_initializer().
    sess.run(tf.initialize_all_variables())
    training_losses = []
    X_test, Y_test = getTestData()
    # Monotonic summary step: one tick per training batch. Before this,
    # every batch in an epoch was logged at the same step (the epoch index).
    global_step = 0
    epochs = genEpochs(num_epochs, num_batches, num_steps, batch_size,
                       num_classes, copy_len)
    for epoch, (X_epoch, Y_epoch) in enumerate(epochs):
        print("EPOCH %d" % epoch)
        total_loss = 0.0
        n_batches = len(X_epoch)
        for X, Y in tqdm(zip(X_epoch, Y_epoch), total=n_batches):
            _, loss_, train_summary_ = sess.run(
                [train_step, loss, train_summary],
                feed_dict={x: X, y: Y},
                options=run_options, run_metadata=run_metadata)
            total_loss += loss_
            train_writer.add_summary(train_summary_, global_step)
            global_step += 1
        test_loss, test_summary_, accuracy_ = sess.run(
            [loss, test_summary, accuracy],
            feed_dict={x: X_test, y: Y_test},
            options=run_options, run_metadata=run_metadata)
        # Log test metrics at the step where this epoch's training ended, so
        # train and test curves share one x-axis.
        train_writer.add_summary(test_summary_, global_step)
        # Average over the batches actually seen; an empty epoch yields 0.0
        # instead of depending on the global num_batches.
        training_loss = total_loss / n_batches if n_batches else 0.0
        training_losses.append(training_loss)
        print("train loss:", training_loss, "test loss:", test_loss,
              "test accuracy:", accuracy_)
    # run_metadata is overwritten by every profiled sess.run call above, so
    # this trace reflects the last one (the final test evaluation).
    tl = timeline.Timeline(run_metadata.step_stats)
    with open('timeline_add.json', 'w') as f:
        f.write(tl.generate_chrome_trace_format())
    return training_losses
# (Web-page residue from the source article: "Comment list", "Table of contents".)