def build_train_func(rank=0, learning_rate=1e-4, **kwargs):
    """Build and compile a Theano SGD training function for the ResNet model.

    Parameters
    ----------
    rank : int
        Worker rank; used only in the progress log messages.
    learning_rate : float
        SGD step size. Defaults to 1e-4 (previously hard-coded).
    **kwargs
        Unused; accepted for call-site compatibility.

    Returns
    -------
    (callable, str)
        The compiled function mapping ``(x, y) -> mean cross-entropy loss``
        (and applying the SGD parameter updates as a side effect), plus the
        tag string ``"original"``.
    """
    print("rank: {} Building model".format(rank))
    resnet = build_resnet()
    print("Building training function")

    x = T.ftensor4('x')
    y = T.imatrix('y')

    # Forward pass in training mode (deterministic=False enables stochastic
    # layers such as dropout, if present in the network).
    prob = L.get_output(resnet['prob'], x, deterministic=False)

    # Mean cross-entropy over the minibatch; y is flattened from an imatrix
    # to the 1-D integer-label vector categorical_crossentropy expects.
    loss = T.nnet.categorical_crossentropy(prob, y.flatten()).mean()

    # NOTE(review): resnet appears to be a dict of layers. Wrap .values() in
    # list() so Lasagne receives a real list on Python 3 — get_all_params only
    # recognizes list/tuple as layer collections, not dict views.
    params = L.get_all_params(list(resnet.values()), trainable=True)
    sgd_updates = updates.sgd(loss, params, learning_rate=learning_rate)

    # Compile: returns the (already averaged) loss and applies SGD updates.
    f_train = theano.function(inputs=[x, y],
                              outputs=loss,
                              updates=sgd_updates)
    return f_train, "original"