Python `False` 的实例源码

train.py 文件源码 — 项目: iaf，作者: openai（阅读 23 · 收藏 0 · 点赞 0 · 评论 0）
def get_data(problem, n_train, n_batch):
    """Load the train/validation data for a named problem plus an init minibatch.

    Args:
        problem: One of ``'cifar10'``, ``'svhn'``, ``'mnist'`` or ``'lfw'``.
        n_train: If > 0, keep only rows ``[0, n_train)`` of BOTH the training
            and the validation data (via ``G.ndict.getRows``).
        n_batch: Number of leading training inputs to put in ``data_init``.

    Returns:
        ``(data_train, data_valid, data_init)``. The first two are the dicts
        returned by the ``G.misc.data`` loaders (they carry at least an
        ``'x'`` entry). ``data_init`` is ``{'x': data_train['x'][:n_batch]}``,
        taken from the training data BEFORE any ``n_train`` truncation.

    Raises:
        ValueError: If ``problem`` is not one of the supported names.
    """
    if problem == 'cifar10':
        data_train, data_valid = G.misc.data.cifar10(False)
    elif problem == 'svhn':
        data_train, data_valid = G.misc.data.svhn(False, True)
    elif problem == 'mnist':
        # With a first argument of False the loader returns a (train, valid)
        # pair; passing True instead makes it return a third (test) split.
        data_train, data_valid = G.misc.data.mnist_binarized(False, False)
        # Restore image layout: (N, channels=1, 28, 28).
        data_train['x'] = data_train['x'].reshape((-1, 1, 28, 28))
        data_valid['x'] = data_valid['x'].reshape((-1, 1, 28, 28))
    elif problem == 'lfw':
        data_train = G.misc.data.lfw(False, True)
        # NOTE(review): the validation rows are the first 1000 TRAINING rows,
        # so the two sets overlap -- confirm this is intended.
        data_valid = G.ndict.getRows(data_train, 0, 1000)
    else:
        # Fail loudly instead of leaving data_train unbound (which used to
        # surface later as an UnboundLocalError).
        raise ValueError(
            "unknown problem %r; expected 'cifar10', 'svhn', 'mnist' or 'lfw'"
            % (problem,)
        )

    # Taken before the n_train truncation below.
    data_init = {'x': data_train['x'][:n_batch]}

    if n_train > 0:
        data_train = G.ndict.getRows(data_train, 0, n_train)
        # NOTE(review): the validation set is truncated to n_train as well.
        data_valid = G.ndict.getRows(data_valid, 0, n_train)

    return data_train, data_valid, data_init


问题


面经


文章

微信
公众号

扫码关注公众号