def main():
    """Command-line entry point: parse options, load data and train the model.

    Steps, in order:

    1. Parse the command-line options (``--save``, ``--nEpoch``,
       ``--nBundleIter``, ``--trainBatchSz``, ``--noncvx``, ``--seed``).
    2. Seed the NumPy and TensorFlow random number generators with ``--seed``.
    3. Recreate the output directory given by ``--save`` together with a
       ``ckpt`` sub-directory. WARNING: an existing directory at that path
       is deleted first.
    4. Load the Olivetti data from ``data/olivetti`` and call
       ``Model.train`` inside a TensorFlow session.

    Exits through ``parser.error`` (usage message, exit status 2) when
    ``--noncvx`` is passed, because this script rejects that flag.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--save', type=str, default='work/mse.ebundle')
    parser.add_argument('--nEpoch', type=float, default=50)
    parser.add_argument('--nBundleIter', type=int, default=30)
    # parser.add_argument('--trainBatchSz', type=int, default=25)
    parser.add_argument('--trainBatchSz', type=int, default=70)
    # parser.add_argument('--testBatchSz', type=int, default=2048)
    parser.add_argument('--noncvx', action='store_true')
    parser.add_argument('--seed', type=int, default=42)
    # parser.add_argument('--valSplit', type=float, default=0)
    args = parser.parse_args()

    # Reject the flag explicitly instead of with `assert`: assertions are
    # removed under `python -O`, which would let the flag slip through.
    if args.noncvx:
        parser.error('--noncvx is not supported by this script')

    setproctitle.setproctitle('bamos.icnn.comp.mse.ebundle')

    # Seed both RNGs from the same command-line value.
    npr.seed(args.seed)
    tf.set_random_seed(args.seed)

    # Start from an empty output directory: wipe any previous run's results.
    save = os.path.expanduser(args.save)
    if os.path.isdir(save):
        shutil.rmtree(save)
    os.makedirs(save)

    # The checkpoint directory is handed to the model through `args`.
    ckptDir = os.path.join(save, 'ckpt')
    args.ckptDir = ckptDir
    if not os.path.exists(ckptDir):
        os.makedirs(ckptDir)

    data = olivetti.load("data/olivetti")
    # Clipping the data into (eps, 1 - eps) is currently disabled:
    # eps = 1e-8
    # data['trainX'] = data['trainX'].clip(eps, 1.-eps)
    # data['trainY'] = data['trainY'].clip(eps, 1.-eps)
    # data['testX'] = data['testX'].clip(eps, 1.-eps)
    # data['testY'] = data['testY'].clip(eps, 1.-eps)

    nTrain = data['trainX'].shape[0]
    nTest = data['testX'].shape[0]

    # Per-sample shapes, both taken from the first training example so that
    # a single-sample training set does not raise IndexError.
    inputSz = list(data['trainX'][0].shape)
    outputSz = list(data['trainY'][0].shape)

    print("\n\n" + "="*40)
    print("+ nTrain: {}, nTest: {}".format(nTrain, nTest))
    print("+ inputSz: {}, outputSz: {}".format(inputSz, outputSz))
    print("="*40 + "\n\n")

    config = tf.ConfigProto()  # log_device_placement=False)
    # Let TensorFlow grow its GPU memory allocation as needed.
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        model = Model(inputSz, outputSz, sess)
        model.train(args, data['trainX'], data['trainY'],
                    data['testX'], data['testY'])
# [Web-page navigation residue, not part of the script: "Comment list" / "Table of contents"]