import os
import time
import cPickle

import tensorflow as tf

# DataLoader and Model come from the project's own modules;
# adjust these import paths to match the repository layout.
from utils import DataLoader
from model import Model


def train(args):
    data_loader = DataLoader(args.data_dir, args.batch_size, args.seq_length)
    # Persist the training arguments so sampling/eval scripts can rebuild the model.
    with open(os.path.join(args.save_dir, 'config.pkl'), 'wb') as f:
        cPickle.dump(args, f)

    model = Model(args)

    with tf.Session() as sess:
        tf.initialize_all_variables().run()
        saver = tf.train.Saver(tf.all_variables())
        for e in xrange(args.num_epochs):
            # Decay the learning rate once per epoch.
            sess.run(tf.assign(model.lr, args.learning_rate * (args.decay_rate ** e)))
            data_loader.reset_batch_pointer()
            state = model.initial_state.eval()
            for b in xrange(data_loader.num_batches):
                start = time.time()
                x, y = data_loader.next_batch()
                feed = {
                    model.input_data: x,
                    model.targets: y,
                    model.initial_state: state
                }
                # Carry the final RNN state over to the next batch.
                train_loss, state, _ = sess.run(
                    [model.cost, model.final_state, model.train_op], feed)
                end = time.time()
                print "{}/{} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \
                    .format(e * data_loader.num_batches + b,
                            args.num_epochs * data_loader.num_batches,
                            e, train_loss, end - start)
                if (e * data_loader.num_batches + b) % args.save_every == 0:
                    checkpoint_path = os.path.join(args.save_dir, 'model.ckpt')
                    saver.save(sess, checkpoint_path,
                               global_step=e * data_loader.num_batches + b)
                    print "model saved to {}".format(checkpoint_path)
def train(args):
    # Use all four datasets except the one held out for evaluation.
    datasets = list(range(4))
    # Remove the leaveDataset from datasets
    datasets.remove(args.leaveDataset)

    # Create the data loader object. This object would preprocess the data in terms of
    # batches each of size args.batch_size, of length args.seq_length
    data_loader = DataLoader(args.batch_size, args.seq_length, datasets, forcePreProcess=True)

    # Save the arguments in the config file
    with open(os.path.join('save_lstm', 'config.pkl'), 'wb') as f:
        pickle.dump(args, f)

    # Create a Vanilla LSTM model with the arguments
    model = Model(args)

    # Initialize a TensorFlow session
    with tf.Session() as sess:
        # Initialize all the variables in the graph
        sess.run(tf.initialize_all_variables())
        # Add all the variables to the list of variables to be saved
        saver = tf.train.Saver(tf.all_variables())

        # For each epoch
        for e in range(args.num_epochs):
            # Assign the learning rate (decayed acc. to the epoch number)
            sess.run(tf.assign(model.lr, args.learning_rate * (args.decay_rate ** e)))
            # Reset the pointers in the data loader object
            data_loader.reset_batch_pointer()
            # Get the initial cell state of the LSTM
            state = sess.run(model.initial_state)

            # For each batch in this epoch
            for b in range(data_loader.num_batches):
                # Tic
                start = time.time()
                # Get the source and target data of the current batch
                # x has the source data, y has the target data
                x, y = data_loader.next_batch()
                # Feed the source, target data and the initial LSTM state to the model
                feed = {model.input_data: x, model.target_data: y, model.initial_state: state}
                # Fetch the loss of the model on this batch and the final LSTM state from the session
                train_loss, state, _ = sess.run([model.cost, model.final_state, model.train_op], feed)
                # Toc
                end = time.time()

                # Print epoch, batch, loss and time taken
                print(
                    "{}/{} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}".format(
                        e * data_loader.num_batches + b,
                        args.num_epochs * data_loader.num_batches,
                        e, train_loss, end - start))

                # Save the model whenever the global batch counter is a non-zero multiple of save_every
                if (e * data_loader.num_batches + b) % args.save_every == 0 and (e * data_loader.num_batches + b) > 0:
                    checkpoint_path = os.path.join('save_lstm', 'model.ckpt')
                    saver.save(sess, checkpoint_path, global_step=e * data_loader.num_batches + b)
                    print("model saved to {}".format(checkpoint_path))