omniglot.py 文件源码

python
阅读 19 收藏 0 点赞 0 评论 0

项目:ntm-one-shot 作者: tristandeleu 项目源码 文件源码
def omniglot():
    """Train a memory-augmented neural network on Omniglot episodes.

    Builds the symbolic cost and accuracy expressions, compiles the Theano
    training and accuracy functions, then consumes episodes from
    ``OmniglotGenerator`` until it is exhausted or the user presses Ctrl-C.

    Every 100 episodes the mean training cost and the mean accuracy vector
    over the episodes seen since the previous report are printed, and the
    running window is reset. On ``KeyboardInterrupt`` the elapsed wall-clock
    time (seconds) is printed and the function returns normally.

    Returns:
        None. All results are reported on stdout.
    """
    # input_var has dimensions (batch_size, time, input_dim)
    input_var = T.tensor3('input')
    # target_var has dimensions (batch_size, time) (label indices)
    target_var = T.imatrix('target')

    # Load data
    generator = OmniglotGenerator(
        data_folder='./data/omniglot', batch_size=16,
        nb_samples=5, nb_samples_per_class=10,
        max_rotation=0., max_shift=0, max_iter=None)

    output_var, output_var_flatten, params = memory_augmented_neural_network(
        input_var, target_var,
        batch_size=generator.batch_size, nb_class=generator.nb_samples,
        memory_shape=(128, 40), controller_size=200,
        input_size=20 * 20, nb_reads=4)

    cost = T.mean(T.nnet.categorical_crossentropy(
        output_var_flatten, target_var.flatten()))
    updates = lasagne.updates.adam(cost, params, learning_rate=1e-3)

    accuracies = accuracy_instance(
        T.argmax(output_var, axis=2), target_var,
        batch_size=generator.batch_size)

    print('Compiling the model...')
    train_fn = theano.function([input_var, target_var], cost, updates=updates)
    accuracy_fn = theano.function([input_var, target_var], accuracies)
    print('Done')

    print('Training...')
    t0 = time.time()
    # Running window of costs / summed accuracies since the last report.
    # NOTE(review): accuracy vector length is taken from
    # generator.nb_samples_per_class, as in the original code -- confirm
    # against accuracy_instance.
    scores = []
    accs = np.zeros(generator.nb_samples_per_class)
    try:
        for i, (example_input, example_output) in generator:
            # Accuracy is measured after the parameter update, on the same
            # episode that was just trained on (order kept from the original).
            score = train_fn(example_input, example_output)
            acc = accuracy_fn(example_input, example_output)
            scores.append(score)
            accs += acc
            if i > 0 and not (i % 100):
                # Report the window average, not just the latest episode's
                # cost, and normalise by the real number of accumulated
                # episodes (the first window may hold one more than 100).
                print('Episode %05d: %.6f' % (i, np.mean(scores)))
                print(accs / len(scores))
                scores, accs = [], np.zeros(generator.nb_samples_per_class)
    except KeyboardInterrupt:
        # Ctrl-C is the expected way to stop training; report elapsed time.
        print(time.time() - t0)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号