main.py 文件源码

python
阅读 48 收藏 0 点赞 0 评论 0

项目:sdp 作者: tansey 项目源码 文件源码
def explicit_score(sess, model, dist, data, tf_X):
    logprobs = 0
    squared_err = 0
    indices = np.array(list(np.ndindex(dist._num_classes)))
    n = 0
    for X, y in data:
        for i in xrange(len(X)):
            feed_dict = test_dict(model, dist, X[i:i+1], y[i:i+1])
            feed_dict[tf_X] = X[i:i+1]
            density = sess.run(dist.density, feed_dict=feed_dict)[0]
            logprobs += np.log(density[tuple(y[i])])
            prediction = np.array([density[tuple(idx)] * idx for idx in indices]).sum(axis=0)
            squared_err += np.linalg.norm(y[i] - prediction)**2
            n += 1
    rmse = np.sqrt(squared_err / float(n))
    print 'Explicit logprobs: {0} RMSE: {1}'.format(logprobs, rmse)
    return logprobs, rmse
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号