logistic_regression.py source file

python

Project: tf229  Author: knathanieltucker
import numpy as np
import tensorflow as tf


def predict(x, model_path):
    """Predicts targets using a batch of predictors and a model trained by
    the logistic regression train method.
    Args:
      x: The covariates or factors of the model in an n by m array (n is the
        number of data points and m is the number of factors)
      model_path: location of the tf model file
    Raises:
      TODO
    Returns:
      a num data by 1 array of predictions
    """
    num_predictors = len(x[0])
    num_data = len(x)

    x = np.array(x)

    with tf.Graph().as_default():
        X = tf.placeholder(tf.float32, [num_data, num_predictors])

        W = tf.Variable(tf.zeros([num_predictors, 1]))
        b = tf.Variable(1.0)

        saver = tf.train.Saver([W, b])

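        # Logistic function 1 / (1 + exp(-z)). tf.sigmoid replaces the original
        # tf.inv expression; tf.inv was renamed tf.reciprocal in TensorFlow 1.0.
        Predictions = tf.sigmoid(tf.matmul(X, W) + b)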

        with tf.Session() as sess:
            saver.restore(sess, model_path)

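            # Fetch the tensor itself (not a one-element list) so the result
            # is a single [num_data, 1] array, as the docstring promises.
            predictions = sess.run(Predictions, feed_dict={X: x})

            return predictions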
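
The computation that the graph above performs can be checked without TensorFlow or a saved model. The following is a minimal NumPy sketch of the same formula, sigmoid(xW + b). The name `sigmoid_predict` and the explicit `W` and `b` arguments are hypothetical and not part of the original project, which restores its weights from `model_path` instead.

```python
import numpy as np


def sigmoid_predict(x, W, b):
    """Hypothetical NumPy stand-in for the graph built in predict().

    x: n by m array of predictors
    W: m weights (any shape that reshapes to m by 1)
    b: scalar bias
    Returns an n by 1 array of predicted probabilities.
    """
    x = np.asarray(x, dtype=np.float32)                 # shape (n, m)
    W = np.asarray(W, dtype=np.float32).reshape(-1, 1)  # shape (m, 1)
    z = x @ W + b                                       # linear term, shape (n, 1)
    return 1.0 / (1.0 + np.exp(-z))                     # logistic function


# Illustrative values only: zero weights and zero bias give z = 0 for every row.
preds = sigmoid_predict([[1.0, 2.0], [3.0, 4.0]], W=[0.0, 0.0], b=0.0)
print(preds.shape)
```

With all-zero parameters every prediction is 0.5, which makes a convenient sanity check before restoring real weights.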