def test(net_file, data_set, label_method, model='RNN', trees=None):
    """Evaluate a saved network on a data set and report cost, F1 and accuracy.

    The model file is read in order as: a pickled options object (``opts``),
    a second pickled object that is not used here, and then the network
    parameters, which are consumed by ``nn.from_file``.

    Args:
        net_file: Path to the saved model file. Must not be None.
        data_set: Data-set identifier passed to ``tree.load_all`` when
            ``trees`` is not supplied.
        label_method: Labelling scheme passed to ``tree.load_all``.
        model: Network class to build: 'RNTN', 'RNN', 'TreeLSTM' or
            'TreeTLSTM'.
        trees: Pre-loaded trees to evaluate; loaded via ``tree.load_all``
            when None.

    Returns:
        Support-weighted F1 (as a percentage) over every class except
        class 0, or 0.0 when none of those classes occur in the gold labels.

    Raises:
        ValueError: if ``net_file`` is None or ``model`` is not recognised.
    """
    # Validate before doing any (possibly expensive) tree loading; an
    # explicit raise is not stripped under ``python -O`` like ``assert`` is.
    if net_file is None:
        raise ValueError("Must give model to test")
    if trees is None:
        trees = tree.load_all(data_set, label_method)
    print("Testing netFile %s" % net_file)

    # NOTE(security): pickle can execute arbitrary code while loading;
    # only pass model files from a trusted source.
    # Binary mode is required for pickles written with protocol >= 1.
    with open(net_file, 'rb') as fid:
        opts = pickle.load(fid)
        pickle.load(fid)  # second pickled object is not needed for testing
        if model == 'RNTN':
            nn = RNTN(opts.wvec_dim, opts.output_dim, opts.num_words,
                      opts.minibatch)
        elif model == 'RNN':
            nn = RNN(opts.wvec_dim, opts.output_dim, opts.num_words,
                     opts.minibatch)
        elif model == 'TreeLSTM':
            nn = TreeLSTM(opts.wvec_dim, opts.mem_dim, opts.output_dim,
                          opts.num_words, opts.minibatch, rho=opts.rho)
        elif model == 'TreeTLSTM':
            nn = TreeTLSTM(opts.wvec_dim, opts.mem_dim, opts.output_dim,
                           opts.num_words, opts.minibatch, rho=opts.rho)
        else:
            raise ValueError(
                "%s is not a valid neural network; expected one of "
                "RNTN, RNN, TreeLSTM, TreeTLSTM" % model)
        nn.init_params()
        # The parameters follow the two pickled objects in the same file,
        # so they must be read while the file is still open.
        nn.from_file(fid)

    print("Testing %s..." % model)
    cost, correct, guess = nn.cost_and_grad(trees, test=True)

    num_correct = sum(1 for gold, pred in zip(correct, guess) if gold == pred)
    accuracy = num_correct / float(len(correct)) if len(correct) else 0.0

    # Pin the label order to 0..output_dim-1 so index 0 of every returned
    # array is always class 0, even when class 0 is absent from this data;
    # otherwise the [1:] slices below would drop a different class.
    _, _, f1_per_class, support = metrics.precision_recall_fscore_support(
        correct, guess, labels=list(range(nn.output_dim)))

    # Support-weighted F1 over all classes except class 0.
    total_support = sum(support[1:])
    if total_support > 0:
        f1 = 100.0 * sum(f1_per_class[1:] * support[1:]) / total_support
    else:
        f1 = 0.0
    print("Cost %f, F1 %f, Acc %f" % (cost, f1, accuracy))
    return f1
评论列表
文章目录