linear-regrssion.py source file

python

Project: rascal-tensorflow  Author: stayrascal
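import tensorflow as tf  # TensorFlow 1.x only: tf.contrib was removed in TensorFlow 2.0


def model(features, labels, mode, params):
    """Custom model_fn for tf.contrib.learn.Estimator: one-feature linear regression."""
    # Pin every op in the model to the CPU.
    with tf.device("/cpu:0"):
        # Build a linear model and predict values: y = W * x + b, where x is the
        # first column of the feature tensor. `params` is accepted but unused.
        W = tf.get_variable("W", [1], dtype=tf.float64)
        b = tf.get_variable("b", [1], dtype=tf.float64)
        y = W * features[:, 0] + b
        # Loss sub-graph: sum of squared errors over the batch.
        loss = tf.reduce_sum(tf.square(y - labels))
        # Training sub-graph: one gradient-descent step (learning rate 0.01),
        # grouped with a manual increment of the global step.
        global_step = tf.train.get_global_step()
        optimizer = tf.train.GradientDescentOptimizer(0.01)
        train = tf.group(optimizer.minimize(loss),
                         tf.assign_add(global_step, 1))
        # ModelFnOps connects the sub-graphs built above to the Estimator's
        # train/evaluate/predict functionality (exported as tf.contrib.learn.ModelFnOps).
        return tf.contrib.learn.ModelFnOps(
            mode=mode, predictions=y,
            loss=loss,
            train_op=train)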
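The graph above is small enough to spell out by hand. The sketch below is a framework-free, pure-Python version of the same computation (prediction y = W * x + b, summed squared-error loss, plain gradient descent with learning rate 0.01, and a step counter standing in for global_step), which can help check what the TensorFlow sub-graphs do without a TensorFlow 1.x install. It assumes a single input feature (the features[:, 0] column), zero-initialized W and b rather than tf.get_variable's default initializer, and full-batch updates; the function name train_linear_model and the toy data are hypothetical.

```python
from typing import List, Tuple


def train_linear_model(
    xs: List[float],
    ys: List[float],
    learning_rate: float = 0.01,
    steps: int = 1000,
) -> Tuple[float, float, float, int]:
    """Fit y = W * x + b by full-batch gradient descent on the summed squared error.

    Hypothetical helper; returns (W, b, final_loss, global_step).
    """
    if len(xs) != len(ys) or not xs:
        raise ValueError("xs and ys must be non-empty and of the same length")
    # Assumption: zero initialization (tf.get_variable would use its default initializer).
    W, b = 0.0, 0.0
    global_step = 0
    for _ in range(steps):
        # Residuals e_i = (W * x_i + b) - y_i of the current predictions.
        residuals = [W * x + b - y for x, y in zip(xs, ys)]
        # Gradients of loss = sum(e_i ** 2): dL/dW = 2 * sum(e_i * x_i), dL/db = 2 * sum(e_i).
        grad_W = 2.0 * sum(e * x for e, x in zip(residuals, xs))
        grad_b = 2.0 * sum(residuals)
        # One gradient-descent update, then bump the counter
        # (mirrors optimizer.minimize(loss) grouped with tf.assign_add(global_step, 1)).
        W -= learning_rate * grad_W
        b -= learning_rate * grad_b
        global_step += 1
    final_loss = sum((W * x + b - y) ** 2 for x, y in zip(xs, ys))
    return W, b, final_loss, global_step


if __name__ == "__main__":
    # Hypothetical toy data lying exactly on the line y = -x + 1.
    W, b, loss, step = train_linear_model([1.0, 2.0, 3.0, 4.0], [0.0, -1.0, -2.0, -3.0])
    print(f"W={W:.4f} b={b:.4f} loss={loss:.2e} global_step={step}")
```

On this toy data W and b should end up very close to -1 and 1. Because the loss is a sum rather than a mean, the usable learning rate shrinks as the batch or the feature values grow; with these four points 0.01 is comfortably stable, but larger inputs may need a smaller rate.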