main.py source code

python

Project: DeepEnhancer    Author: minxueric
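import hickle as hkl
import theano
from numpy.random import shuffle
from lasagne.layers import (InputLayer, Conv2DLayer, MaxPool2DLayer,
                            DenseLayer, DropoutLayer)
from lasagne.nonlinearities import softmax
from lasagne.updates import adam
from nolearn.lasagne import NeuralNet, TrainSplit
from nolearn.lasagne.visualize import plot_loss
# AdjustVariable is a project-specific on_epoch_finished callback defined
# elsewhere in DeepEnhancer; it is not part of this file excerpt.

floatX = theano.config.floatX  # dtype string used by Theano, e.g. 'float32'


def main():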
    ################
    # LOAD DATASET #
    ################
    dataset = './data/ubiquitous_aug.hkl'
    kfd = './data/ubiquitous_kfold.hkl'
    print('Loading dataset {}...'.format(dataset))
    X, y = hkl.load(dataset)  # pass the path; hickle opens the HDF5 file itself
    X = X.reshape(-1, 4, 1, 400).astype(floatX)
    y = y.astype('int32')
    print('X shape: {}, y shape: {}'.format(X.shape, y.shape))
    kf = hkl.load(kfd)  # stored (train_indices, test_indices) pairs, one per fold
    kfold = [(train, test) for train, test in kf]
    (train, test) = kfold[0]  # only the first fold is used in this run
    print('train_set size: {}, test_set size: {}'.format(len(train), len(test)))
    # shuffle the index arrays so positive and negative samples are mixed within each minibatch
    print('shuffling train_set and test_set')
    shuffle(train)
    shuffle(test)
    X_train = X[train]
    X_test = X[test]
    y_train = y[train]
    y_test = y[test]
    print('data prepared!')

    layers = [
            (InputLayer, {'shape': (None, 4, 1, 400)}),
            (Conv2DLayer, {'num_filters': 64, 'filter_size': (1, 4)}),
            (Conv2DLayer, {'num_filters': 64, 'filter_size': (1, 3)}),
            (Conv2DLayer, {'num_filters': 64, 'filter_size': (1, 3)}),
            (MaxPool2DLayer, {'pool_size': (1, 2)}),
            (Conv2DLayer, {'num_filters': 32, 'filter_size': (1, 2)}),
            (Conv2DLayer, {'num_filters': 32, 'filter_size': (1, 2)}),
            (Conv2DLayer, {'num_filters': 32, 'filter_size': (1, 2)}),
            (MaxPool2DLayer, {'pool_size': (1, 2)}),
            (DenseLayer, {'num_units': 64}),
            (DropoutLayer, {}),
            (DenseLayer, {'num_units': 64}),
            (DenseLayer, {'num_units': 2, 'nonlinearity': softmax})]

    net = NeuralNet(
            layers=layers,
            max_epochs=100,
            update=adam,
            update_learning_rate=1e-4,
            train_split=TrainSplit(eval_size=0.1),  # hold out 10% of X_train for validation
            on_epoch_finished=[
                # project-specific learning-rate callback
                AdjustVariable(1e-4, target=0, half_life=20)],
            verbose=2)

    net.fit(X_train, y_train)
    plot_loss(net)
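`main()` reshapes every sample with `X.reshape(-1, 4, 1, 400)`, but the excerpt does not show how `ubiquitous_aug.hkl` was built. Suppose each sample is a one-hot encoded 400 bp DNA sequence stored channel by channel, either as a flat 1,600-value vector or as an equivalent `(4, 400)` array. The A/C/G/T channel order is also an assumption. Under those assumptions, the hypothetical helper `one_hot_channel_major` below produces a vector that NumPy's default C-order reshape maps to `(4, 1, 400)`. Element `[c, 0, p]` then holds the indicator for base `c` at position `p`.

```python
def one_hot_channel_major(seq, alphabet="ACGT"):
    """Hypothetical encoder: one-hot a DNA string into a flat, channel-major list.

    Assumes the channel order given by `alphabet`. With NumPy's default C order,
    reshaping the result to (len(alphabet), 1, len(seq)) puts the indicator for
    base alphabet[c] at position p in element [c, 0, p].
    """
    length = len(seq)
    flat = [0.0] * (len(alphabet) * length)
    for p, base in enumerate(seq.upper()):
        c = alphabet.find(base)
        if c >= 0:  # bases outside the alphabet (e.g. 'N') stay all-zero
            flat[c * length + p] = 1.0
    return flat


def channel_position(flat_index, length=400):
    """Map a flat index back to (channel, position) under the same C-order layout."""
    return divmod(flat_index, length)


if __name__ == "__main__":
    print(one_hot_channel_major("ACGT"))
    print(channel_position(1203))  # channel 3, position 3
```

Mapping flat index `c * 400 + p` to `[c, 0, p]` follows from row-major reshaping alone. Only the encoding and channel order are assumptions.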
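The `layers` list in `main()` fixes the output shape of every layer. The hypothetical `trace` helper below recomputes those shapes and the trainable-parameter counts in plain Python. It assumes Lasagne's defaults: `Conv2DLayer` uses stride 1 and no padding, `MaxPool2DLayer` uses a stride equal to `pool_size` with `ignore_border=True`, and `DenseLayer` flattens all non-batch axes.

```python
# Mirror of the `layers` list in main(): (kind, num_filters/num_units, size).
LAYERS = [
    ("conv", 64, (1, 4)),
    ("conv", 64, (1, 3)),
    ("conv", 64, (1, 3)),
    ("pool", None, (1, 2)),
    ("conv", 32, (1, 2)),
    ("conv", 32, (1, 2)),
    ("conv", 32, (1, 2)),
    ("pool", None, (1, 2)),
    ("dense", 64, None),
    ("dropout", None, None),
    ("dense", 64, None),
    ("dense", 2, None),
]


def trace(input_shape=(4, 1, 400), layers=LAYERS):
    """Return (kind, output_shape, n_params) per layer, batch axis omitted."""
    shape = input_shape
    rows = []
    for kind, units, size in layers:
        if kind == "conv":
            c, h, w = shape
            fh, fw = size
            params = c * fh * fw * units + units  # weights + one bias per filter
            shape = (units, h - fh + 1, w - fw + 1)  # "valid" convolution, stride 1
        elif kind == "pool":
            c, h, w = shape
            ph, pw = size
            params = 0
            shape = (c, h // ph, w // pw)  # stride = pool size, leftover column dropped
        elif kind == "dense":
            n_in = 1
            for d in shape:  # flatten every non-batch axis
                n_in *= d
            params = n_in * units + units
            shape = (units,)
        else:  # dropout: no parameters, shape unchanged
            params = 0
        rows.append((kind, shape, params))
    return rows


if __name__ == "__main__":
    rows = trace()
    for kind, shape, params in rows:
        print(f"{kind:8s} {str(shape):16s} {params}")
    print("total parameters:", sum(p for _, _, p in rows))
```

Under these assumptions, the first `DenseLayer` receives 32 × 1 × 96 = 3,072 inputs. That layer alone holds 196,672 of the 235,042 trainable parameters, about 84% of the model.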
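The listing does not show `AdjustVariable`, so its exact behavior is unknown. One reading of `AdjustVariable(1e-4, target=0, half_life=20)` is an exponential decay of the learning rate from 1e-4 toward 0 with a half-life of 20 epochs. That reading is an assumption, and the hypothetical `HalfLifeDecay` callback below sketches it. nolearn calls each `on_epoch_finished` handler as `handler(net, train_history)`. For such a callback to change the rate after the training function is compiled, `update_learning_rate` would normally need to be a Theano shared variable, for example `theano.shared(np.float32(1e-4))`. The listing passes a plain float instead.

```python
class HalfLifeDecay:
    """Hypothetical stand-in for the project's AdjustVariable (behavior assumed).

    value(epoch) = target + (start - target) * 0.5 ** (epoch / half_life)
    """

    def __init__(self, start, target=0.0, half_life=20, name="update_learning_rate"):
        self.start = start
        self.target = target
        self.half_life = half_life
        self.name = name

    def value_at(self, epoch):
        # The gap between start and target halves every `half_life` epochs.
        return self.target + (self.start - self.target) * 0.5 ** (epoch / self.half_life)

    def __call__(self, nn, train_history):
        # nolearn passes the net and its list of per-epoch dicts; each dict
        # carries an 'epoch' key for the epoch that just finished.
        epoch = train_history[-1]["epoch"]
        # Assumes the attribute is a Theano shared variable (it has set_value).
        getattr(nn, self.name).set_value(self.value_at(epoch))


if __name__ == "__main__":
    decay = HalfLifeDecay(1e-4, target=0, half_life=20)
    for epoch in (0, 20, 40, 100):
        print(epoch, decay.value_at(epoch))
```

With these settings, the rate is 1e-4 at epoch 0, 5e-5 at epoch 20, and 2.5e-5 at epoch 40. By the last of the 100 epochs, it has fallen by a factor of 32.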