def main():
    """Train a model on a synthetic classification dataset.

    Parses the command-line options, seeds NumPy and TensorFlow, recreates
    the output directory ``<save>/<model>.<dataset>``, generates the
    requested dataset, and trains a ``Model`` inside a TensorFlow session.

    Raises:
        ValueError: If ``args.dataset`` is not one of the known datasets.
            argparse's ``choices`` already rejects such values, so this is
            only a safeguard should the two lists drift apart.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--save', type=str, default='work')
    parser.add_argument('--nEpoch', type=int, default=100)
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--model', type=str, default="picnn",
                        choices=['picnn', 'ficnn'])
    parser.add_argument('--dataset', type=str, default="moons",
                        choices=['moons', 'circles', 'linear'])
    parser.add_argument('--noncvx', action='store_true')
    args = parser.parse_args()

    # Seed both random number generators from the same user-supplied seed.
    npr.seed(args.seed)
    tf.set_random_seed(args.seed)

    setproctitle.setproctitle(
        'bamos.icnn.synthetic.{}.{}'.format(args.model, args.dataset))

    # Always start from an empty output directory: results of a previous
    # run with the same model/dataset pair are deleted.
    save = os.path.join(os.path.expanduser(args.save),
                        "{}.{}".format(args.model, args.dataset))
    if os.path.isdir(save):
        shutil.rmtree(save)
    os.makedirs(save, exist_ok=True)

    if args.dataset == "moons":
        (dataX, dataY) = make_moons(noise=0.3, random_state=0)
    elif args.dataset == "circles":
        (dataX, dataY) = make_circles(noise=0.2, factor=0.5, random_state=0)
        # Invert the labels (y -> 1 - y).
        dataY = 1. - dataY
    elif args.dataset == "linear":
        (dataX, dataY) = make_classification(
            n_features=2, n_redundant=0, n_informative=2,
            random_state=1, n_clusters_per_class=1)
        # Perturb every coordinate with uniform noise scaled by 2, drawn
        # from a dedicated, fixed-seed generator.
        rng = np.random.RandomState(2)
        dataX += 2 * rng.uniform(size=dataX.shape)
    else:
        # An explicit raise instead of ``assert(False)``: asserts are
        # removed under ``python -O``, which would leave dataX/dataY unbound.
        raise ValueError("Unknown dataset: {!r}".format(args.dataset))

    # Labels become a float32 column vector with one label per sample.
    dataY = dataY.reshape((-1, 1)).astype(np.float32)

    nFeatures = dataX.shape[1]
    nLabels = 1

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        model = Model(nFeatures, nLabels, sess, args.model, nGdIter=30)
        model.train(args, dataX, dataY)
# Web-page navigation residue ("Comment list", "Table of contents"); not part of the program.