import time

import lasagne.layers as L  # aliases assumed from the surrounding example script
from lasagne import updates
import synkhronos as synk

# load_data, build_resnet, build_training, and iter_mb_idxs are helper
# functions defined elsewhere in the example script.


def train_resnet(
        batch_size=64,  # batch size on each GPU
        validFreq=1,
        do_valid=False,
        learning_rate=1e-3,
        update_rule=updates.sgd,  # e.g. updates.nesterov_momentum
        n_epoch=3,
        n_gpu=None,  # later get this from synk.fork
        **update_kwargs):
    n_gpu = synk.fork(n_gpu)  # (n_gpu==None will use all)
    t_0 = time.time()

    print("Loading data (synthetic)")
    train, valid, test = load_data()
    x_train, y_train = [synk.data(d) for d in train]
    x_valid, y_valid = [synk.data(d) for d in valid]
    x_test, y_test = [synk.data(d) for d in test]

    full_mb_size = batch_size * n_gpu
    learning_rate = learning_rate * n_gpu  # (one technique for larger minibatches)
    num_valid_slices = len(x_valid) // n_gpu // batch_size
    print("Will compute validation using {} slices".format(num_valid_slices))

    print("Building model")
    resnet = build_resnet()
    params = L.get_all_params(resnet.values(), trainable=True)

    f_train_minibatch, f_predict = build_training(resnet, params, update_rule,
                                                  learning_rate=learning_rate,
                                                  **update_kwargs)

    synk.distribute()
    synk.broadcast(params)  # (ensure all GPUs have same values)

    t_last = t_1 = time.time()
    print("Total setup time: {:,.1f} s".format(t_1 - t_0))
    print("Starting training")

    for ep in range(n_epoch):
        train_loss = 0.
        i = 0
        for mb_idxs in iter_mb_idxs(full_mb_size, len(x_train), shuffle=True):
            train_loss += f_train_minibatch(x_train, y_train, batch=mb_idxs)
            i += 1
        train_loss /= i

        print("\nEpoch: ", ep)
        print("Training Loss: {:.3f}".format(train_loss))
        if do_valid and ep % validFreq == 0:
            valid_loss, valid_mc = f_predict(x_valid, y_valid,
                                             num_slices=num_valid_slices)
print("Validation Loss: {:3f}, Accuracy: {:3f}".format(
float(valid_loss), float(1 - valid_mc)))
        t_2 = time.time()
        print("(epoch total time: {:,.1f} s)".format(t_2 - t_last))
        t_last = t_2

    print("\nTotal training time: {:,.1f} s".format(t_last - t_1))