runNNet.py 文件源码

python
阅读 16 收藏 0 点赞 0 评论 0

项目:DRM 作者: JohnZhengHub 项目源码 文件源码
def test(net_file, data_set, label_method, model='RNN', trees=None):
    """Evaluate a saved network on a set of trees and return its F1 score.

    The saved file is read as a sequence of pickled objects: the training
    options, one object that is skipped, and then the network parameters,
    which are restored through ``nn.from_file``.

    Args:
        net_file: Path of the saved network file. Must not be None.
        data_set: Passed to ``tree.load_all`` when ``trees`` is None.
        label_method: Passed to ``tree.load_all`` when ``trees`` is None.
        model: Name of the network class to build; one of 'RNTN', 'RNN',
            'TreeLSTM' or 'TreeTLSTM'.
        trees: Trees to evaluate. Loaded with ``tree.load_all`` when None.

    Returns:
        The support-weighted F1 over every class except the first one
        reported by ``metrics.precision_recall_fscore_support``, scaled to
        0-100. Returns 0.0 when those classes have no support.

    Raises:
        AssertionError: If ``net_file`` is None.
        ValueError: If ``model`` is not a supported network name.
    """
    # Validate before the (potentially expensive) tree loading.
    assert net_file is not None, "Must give model to test"
    if trees is None:
        trees = tree.load_all(data_set, label_method)
    print("Testing netFile %s" % net_file)

    # Pickle data is binary, so the file must be opened in binary mode.
    # NOTE(security): unpickling executes arbitrary code; only load network
    # files that come from a trusted source.
    with open(net_file, 'rb') as fid:
        opts = pickle.load(fid)
        # The second pickled object is not needed for testing
        # (presumably training history -- confirm against the save routine).
        _ = pickle.load(fid)

        # The `model` argument alone decides which network is built.
        if model == 'RNTN':
            nn = RNTN(opts.wvec_dim, opts.output_dim, opts.num_words,
                      opts.minibatch)
        elif model == 'RNN':
            nn = RNN(opts.wvec_dim, opts.output_dim, opts.num_words,
                     opts.minibatch)
        elif model == 'TreeLSTM':
            nn = TreeLSTM(opts.wvec_dim, opts.mem_dim, opts.output_dim,
                          opts.num_words, opts.minibatch, rho=opts.rho)
        elif model == 'TreeTLSTM':
            nn = TreeTLSTM(opts.wvec_dim, opts.mem_dim, opts.output_dim,
                           opts.num_words, opts.minibatch, rho=opts.rho)
        else:
            raise ValueError(
                '%s is not a valid neural network; supported models are '
                'RNTN, RNN, TreeLSTM and TreeTLSTM' % model)

        nn.init_params()
        # Parameters follow the two pickled objects in the same file.
        nn.from_file(fid)

    print("Testing %s..." % model)

    cost, correct, guess = nn.cost_and_grad(trees, test=True)
    num_examples = len(correct)
    num_correct = sum(1 for g, c in zip(guess, correct) if g == c)
    accuracy = num_correct / float(num_examples) if num_examples else 0.0

    pre, rec, f1, support = metrics.precision_recall_fscore_support(
        correct, guess)
    # Weight each class's F1 by its support, leaving out the first class.
    weighted_support = sum(support[1:])
    if weighted_support:
        f1 = 100 * sum(f1[1:] * support[1:]) / weighted_support
    else:
        f1 = 0.0

    print("Cost %f, F1 %f, Acc %f" % (cost, f1, accuracy))
    return f1
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号