icnn.back.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:icnn 作者: locuslab 项目源码 文件源码
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--save', type=str, default='work/mse')
    parser.add_argument('--nEpoch', type=float, default=50)
    # parser.add_argument('--trainBatchSz', type=int, default=25)
    parser.add_argument('--trainBatchSz', type=int, default=70)
    # parser.add_argument('--testBatchSz', type=int, default=2048)
    parser.add_argument('--nGdIter', type=int, default=30)
    parser.add_argument('--noncvx', action='store_true')
    parser.add_argument('--seed', type=int, default=42)
    # parser.add_argument('--valSplit', type=float, default=0)

    args = parser.parse_args()

    setproctitle.setproctitle('bamos.icnn.comp.mse')

    npr.seed(args.seed)
    tf.set_random_seed(args.seed)

    save = os.path.expanduser(args.save)
    if os.path.isdir(save):
        shutil.rmtree(save)
    os.makedirs(save)
    ckptDir = os.path.join(save, 'ckpt')
    args.ckptDir = ckptDir
    if not os.path.exists(ckptDir):
        os.makedirs(ckptDir)

    data = olivetti.load("data/olivetti")

    nTrain = data['trainX'].shape[0]
    nTest = data['testX'].shape[0]

    inputSz = list(data['trainX'][0].shape)
    outputSz = list(data['trainY'][1].shape)

    print("\n\n" + "="*40)
    print("+ nTrain: {}, nTest: {}".format(nTrain, nTest))
    print("+ inputSz: {}, outputSz: {}".format(inputSz, outputSz))
    print("="*40 + "\n\n")

    config = tf.ConfigProto() #log_device_placement=False)
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        model = Model(inputSz, outputSz, sess, args.nGdIter)
        model.train(args, data['trainX'], data['trainY'], data['testX'], data['testY'])
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号