def _load_pickle(path, label):
    """Load a pickled dataset from *path*, printing progress messages.

    :param path: filesystem path of the pickle file.
    :param label: human-readable name used in the progress message
        (e.g. 'training' or 'validation').
    :return: the unpickled object.
    """
    print('Loading %s data...' % label)
    # NOTE: pickle can execute arbitrary code while loading; only point this
    # at dataset files you produced yourself.
    with open(path, 'rb') as f:
        data = pickle.load(f)
    print('-->done!')
    return data


def train(train_path=None, validation_path=None):
    """
    Train both generator and discriminator.

    The model variant is chosen by the module-level ``flag``:
    'salgan' builds ModelSALGAN and runs salgan_batch_iterator,
    'bce' builds ModelBCE and runs bce_batch_iterator. Any other value
    prints "Invalid input argument." and returns without loading data.

    A random validation sample is written to DIR_TO_SAVE (saliency ground
    truth and input image) and passed to the batch iterator for monitoring.

    :param train_path: pickle file with the training samples; defaults to
        TRAIN_DATA_DIR.
    :param validation_path: pickle file with the validation samples;
        defaults to VALIDATION_DATA_DIR.
    :raises ValueError: if the validation set is empty, since no monitoring
        sample can be drawn from it.
    :return: None
    """
    import os

    # Validate the mode before doing any expensive data loading.
    if flag not in ('salgan', 'bce'):
        print('Invalid input argument.')
        return

    # Previously both sets were read from a hard-coded validation sample file,
    # so the model was trained on validation data. Pass explicit paths for
    # quick debug runs instead of editing this function.
    if train_path is None:
        train_path = TRAIN_DATA_DIR
    if validation_path is None:
        validation_path = VALIDATION_DATA_DIR

    train_data = _load_pickle(train_path, 'training')
    validation_data = _load_pickle(validation_path, 'validation')

    if len(validation_data) == 0:
        raise ValueError('Validation set %r is empty; cannot pick a sample '
                         'to monitor training.' % (validation_path,))

    # Choose a random sample to monitor the training.
    validation_sample = validation_data[random.randrange(len(validation_data))]

    gt_path = os.path.join(DIR_TO_SAVE, 'validationRandomSaliencyGT.png')
    image_path = os.path.join(DIR_TO_SAVE, 'validationRandomImage.png')
    # cv2.imwrite reports failure (e.g. a missing directory) by returning
    # False rather than raising, so check it explicitly.
    if not cv2.imwrite(gt_path, validation_sample.saliency.data):
        print('Warning: could not write %s' % gt_path)
    # The image is converted RGB -> BGR because cv2.imwrite expects BGR order.
    bgr_image = cv2.cvtColor(validation_sample.image.data, cv2.COLOR_RGB2BGR)
    if not cv2.imwrite(image_path, bgr_image):
        print('Warning: could not write %s' % image_path)

    # Create network and run the matching training loop.
    if flag == 'salgan':
        model = ModelSALGAN(INPUT_SIZE[0], INPUT_SIZE[1])
        # To resume from a pre-trained model, e.g.:
        # load_weights(net=model.net['output'], path="nss/gen_", epochtoload=15)
        # load_weights(net=model.discriminator['prob'], path="test_dialted/disrim_", epochtoload=54)
        salgan_batch_iterator(model, train_data, validation_sample.image.data)
    else:  # flag == 'bce' (validated above)
        model = ModelBCE(INPUT_SIZE[0], INPUT_SIZE[1])
        # To resume from a pre-trained model, e.g.:
        # load_weights(net=model.net['output'], path='test/gen_', epochtoload=15)
        bce_batch_iterator(model, train_data, validation_sample.image.data)