keras_test.py 文件源码

python
阅读 32 收藏 0 点赞 0 评论 0

项目:wordsim 作者: recski 项目源码 文件源码
def _read_similarity_data(path):
    """Parse a ``|||``-separated file into a list of ``((w1, w2), score)``.

    Each non-blank line (after stripping surrounding whitespace) is split on
    ``|||``; the first two fields form the pair and the third is parsed with
    ``float``. Any further fields are ignored. Blank lines are skipped.

    Raises:
        IndexError: if a non-blank line has fewer than three fields.
        ValueError: if the third field is not a valid float.
    """
    data = []
    # ``with`` guarantees the handle is closed even if parsing raises.
    with open(path) as handle:
        for raw_line in handle:
            line = raw_line.strip()
            # A blank line (e.g. trailing newline at EOF) would otherwise
            # split into [''] and raise IndexError on fields[1].
            if not line:
                continue
            fields = line.split("|||")
            data.append(((fields[0], fields[1]), float(fields[2])))
    return data


def test():
    """Train a model on word-pair similarity data and print Spearman's R.

    Command-line arguments, read from ``sys.argv``:
      1. path of the ``|||``-separated data file (``w1|||w2|||score`` lines)
      2. path passed to ``GloveEmbedding`` (presumably a GloVe vector file)
      3. integer ``dim`` passed to ``featurize`` (presumably the embedding
         dimensionality -- TODO confirm against ``featurize``)
      4. number of training epochs

    The data is split by ``cut`` into train/devel/test parts. The model is
    fitted on the train part and its devel-set predictions are compared to
    the gold scores with Spearman's rank correlation. The test part is
    currently unused.
    """
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s : " +
        "%(module)s (%(lineno)s) - %(levelname)s - %(message)s")

    data = _read_similarity_data(sys.argv[1])

    # Single-argument print(...) calls produce identical output under
    # Python 2 (print statement) and Python 3 (print function).
    print("sample data: {0}".format(data[:3]))

    train_data, devel_data, test_data = cut(data)

    logging.info('loading model...')
    glove_embedding = GloveEmbedding(sys.argv[2])
    logging.info('done!')
    dim = int(sys.argv[3])
    X_train = featurize(train_data, glove_embedding, dim)

    Y_train = np.array([e[1] for e in train_data])

    logging.info("Input shape: {0}".format(X_train.shape))
    print(X_train[:3])
    logging.info("Label shape: {0}".format(Y_train.shape))
    print(Y_train[:3])

    input_dim = X_train.shape[1]
    output_dim = 1
    model = create_model(input_dim, output_dim)
    # NOTE: ``nb_epoch`` is the Keras 1.x keyword (renamed ``epochs`` in
    # Keras 2); kept to match the Keras version this script targets.
    model.fit(X_train, Y_train, nb_epoch=int(sys.argv[4]), batch_size=32)

    X_devel = featurize(devel_data, glove_embedding, dim)
    Y_devel = np.array([e[1] for e in devel_data])

    # The target is a real-valued score, so use ``predict``. On Sequential
    # models ``predict_proba`` only wraps it (and warns on values outside
    # [0, 1]); ``predict`` exists on every Keras model.
    pred = model.predict(X_devel, batch_size=32)
    # Flatten the (n, 1) prediction array so it pairs element-wise with the
    # 1-D gold scores.
    corr = spearmanr(np.ravel(pred), Y_devel)
    print("Spearman's R: {0}".format(corr))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号