def main():
    """Parse command-line options, prepare the output directory, and train.

    Steps performed:
      1. Parse the command-line arguments defined below.
      2. Set the process title and seed the NumPy and TensorFlow RNGs
         with ``--seed``.
      3. Delete the ``--save`` directory if it already exists, then
         recreate it together with a ``ckpt`` sub-directory (the path of
         which is attached to ``args`` as ``args.ckptDir``).
      4. Load the data through ``olivetti.load("data/olivetti")`` and
         derive the per-sample input/output sizes.
      5. Build a ``Model`` in a TensorFlow session and call its ``train``
         method on the train/test splits.

    WARNING: any existing contents of the ``--save`` directory are removed.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--save', type=str, default='work/mse')
    # NOTE(review): parsed as float, so "--nEpoch 10" yields 10.0 while the
    # default stays the int 50; presumably Model.train accepts fractional
    # epochs -- confirm before changing to int.
    parser.add_argument('--nEpoch', type=float, default=50)
    # parser.add_argument('--trainBatchSz', type=int, default=25)
    parser.add_argument('--trainBatchSz', type=int, default=70)
    # parser.add_argument('--testBatchSz', type=int, default=2048)
    parser.add_argument('--nGdIter', type=int, default=30)
    parser.add_argument('--noncvx', action='store_true')
    parser.add_argument('--seed', type=int, default=42)
    # parser.add_argument('--valSplit', type=float, default=0)
    args = parser.parse_args()

    setproctitle.setproctitle('bamos.icnn.comp.mse')

    # Seed both RNG sources with the same value.
    npr.seed(args.seed)
    tf.set_random_seed(args.seed)

    # Start from an empty output directory on every run.
    save = os.path.expanduser(args.save)
    if os.path.isdir(save):
        shutil.rmtree(save)
    os.makedirs(save)

    # `save` was just recreated empty, so the checkpoint directory cannot
    # exist yet; create it unconditionally.
    ckptDir = os.path.join(save, 'ckpt')
    args.ckptDir = ckptDir
    os.makedirs(ckptDir)

    data = olivetti.load("data/olivetti")

    nTrain = data['trainX'].shape[0]
    nTest = data['testX'].shape[0]

    # Per-sample sizes are taken from the first training example of each
    # array (index 0 for both, so a single-sample training set also works).
    inputSz = list(data['trainX'][0].shape)
    outputSz = list(data['trainY'][0].shape)

    print("\n\n" + "="*40)
    print("+ nTrain: {}, nTest: {}".format(nTrain, nTest))
    print("+ inputSz: {}, outputSz: {}".format(inputSz, outputSz))
    print("="*40 + "\n\n")

    config = tf.ConfigProto()  # log_device_placement=False
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        model = Model(inputSz, outputSz, sess, args.nGdIter)
        model.train(args, data['trainX'], data['trainY'],
                    data['testX'], data['testY'])
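# NOTE: trailing web-page navigation labels ("Comment list", "Table of
# contents") left over from the scraped article; not part of the script.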