icnn.py 文件源码

python
阅读 34 收藏 0 点赞 0 评论 0

项目:icnn 作者: locuslab 项目源码 文件源码
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--save', type=str, default='work')
    parser.add_argument('--nEpoch', type=int, default=100)
    # parser.add_argument('--testBatchSz', type=int, default=2048)
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--model', type=str, default="picnn",
                        choices=['picnn', 'ficnn'])
    parser.add_argument('--dataset', type=str, default="moons",
                        choices=['moons', 'circles', 'linear'])
    parser.add_argument('--noncvx', action='store_true')

    args = parser.parse_args()

    npr.seed(args.seed)
    tf.set_random_seed(args.seed)

    setproctitle.setproctitle('bamos.icnn.synthetic.{}.{}'.format(args.model, args.dataset))

    save = os.path.join(os.path.expanduser(args.save),
                        "{}.{}".format(args.model, args.dataset))
    if os.path.isdir(save):
        shutil.rmtree(save)
    os.makedirs(save, exist_ok=True)

    if args.dataset == "moons":
        (dataX, dataY) = make_moons(noise=0.3, random_state=0)
    elif args.dataset == "circles":
        (dataX, dataY) = make_circles(noise=0.2, factor=0.5, random_state=0)
        dataY = 1.-dataY
    elif args.dataset == "linear":
        (dataX, dataY) = make_classification(n_features=2, n_redundant=0, n_informative=2,
                                             random_state=1, n_clusters_per_class=1)
        rng = np.random.RandomState(2)
        dataX += 2 * rng.uniform(size=dataX.shape)
    else:
        assert(False)

    dataY = dataY.reshape((-1, 1)).astype(np.float32)

    nData = dataX.shape[0]
    nFeatures = dataX.shape[1]
    nLabels = 1
    nXy = nFeatures + nLabels

    config = tf.ConfigProto() #log_device_placement=False)
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        model = Model(nFeatures, nLabels, sess, args.model, nGdIter=30)
        model.train(args, dataX, dataY)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号