mnist.py source code


Project: tensorflow-with-go, author: nilsmagnus

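```python
import argparse
import sys

import tensorflow as tf  # TensorFlow 1.x APIs (tf.placeholder, tf.train, tf.saved_model.builder)
from tensorflow.examples.tutorials.mnist import input_data

FLAGS = None


def main(_):
  mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True)

  # GOLANG: the input tensor must be named ("imageinput") so the Go client can feed it by name.
  x = tf.placeholder(tf.float32, [None, 784], name="imageinput")
  W = tf.Variable(tf.zeros([784, 10]))
  b = tf.Variable(tf.zeros([10]))
  # Softmax-regression logits: y = xW + b
  y = tf.add(tf.matmul(x, W), b)

  # One-hot ground-truth labels
  y_ = tf.placeholder(tf.float32, [None, 10])

  cross_entropy = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y))
  train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)

  sess = tf.InteractiveSession()
  tf.global_variables_initializer().run()

  # Train: 1000 steps of mini-batch gradient descent, 100 images per batch
  for _ in range(1000):
    batch_xs, batch_ys = mnist.train.next_batch(100)
    sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})

  # GOLANG: the inference op must also be named ("infer") so the Go client can fetch it.
  infer = tf.argmax(y, 1, name="infer")

  # Evaluate accuracy on the test set
  correct_prediction = tf.equal(infer, tf.argmax(y_, 1))
  accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
  print(sess.run(accuracy, feed_dict={x: mnist.test.images,
                                      y_: mnist.test.labels}))

  # SavedModelBuilder fails if the "mnistmodel" export directory already exists.
  builder = tf.saved_model.builder.SavedModelBuilder("mnistmodel")

  # GOLANG: tag the model ("serve") so that it can be retrieved at inference time.
  builder.add_meta_graph_and_variables(sess, ["serve"])

  builder.save()


if __name__ == '__main__':
  # Entry point in the style of the TensorFlow 1.x mnist_softmax.py tutorial;
  # the default data_dir below is an assumption, not taken from this project.
  parser = argparse.ArgumentParser()
  parser.add_argument('--data_dir', type=str,
                      default='/tmp/tensorflow/mnist/input_data',
                      help='Directory for storing input data')
  FLAGS, unparsed = parser.parse_known_args()
  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
```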
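The graph built in `main` is plain softmax regression trained with gradient descent. As a rough, framework-free cross-check of what the training step and the named `infer` op compute, here is a minimal NumPy sketch. It assumes NumPy is available; the helper names (`loss`, `grads`, `sgd_step`, `infer`, `accuracy`) are hypothetical and not part of the project, and the demo runs on random synthetic data rather than MNIST.

```python
import numpy as np


def log_softmax(logits):
    # Row-wise log-softmax; subtracting the row max keeps exp() from overflowing.
    shifted = logits - logits.max(axis=1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))


def loss(x, labels, W, b):
    # Mirrors reduce_mean(softmax_cross_entropy_with_logits(labels=y_, logits=y))
    # with y = x @ W + b and one-hot labels.
    logits = x @ W + b
    return float(-np.mean(np.sum(labels * log_softmax(logits), axis=1)))


def grads(x, labels, W, b):
    # d(loss)/d(logits) = (softmax(logits) - labels) / batch_size.
    probs = np.exp(log_softmax(x @ W + b))
    d_logits = (probs - labels) / x.shape[0]
    return x.T @ d_logits, d_logits.sum(axis=0)


def sgd_step(x, labels, W, b, lr=0.5):
    # One plain gradient-descent update, the role GradientDescentOptimizer(0.5) plays.
    dW, db = grads(x, labels, W, b)
    return W - lr * dW, b - lr * db


def infer(x, W, b):
    # Same as tf.argmax(y, 1): the predicted class index for each row.
    return np.argmax(x @ W + b, axis=1)


def accuracy(x, labels, W, b):
    # Fraction of rows whose predicted class matches the one-hot label.
    return float(np.mean(infer(x, W, b) == np.argmax(labels, axis=1)))


if __name__ == "__main__":
    # Zero-centred random features: a synthetic stand-in, NOT real MNIST data.
    rng = np.random.default_rng(0)
    x = rng.random((200, 784)) - 0.5
    labels = np.eye(10)[rng.integers(0, 10, size=200)]
    W, b = np.zeros((784, 10)), np.zeros(10)  # same all-zero initialization as tf.zeros
    # With all-zero weights every class gets probability 1/10, so this is ln(10) ≈ 2.3026.
    print("initial loss:", loss(x, labels, W, b))
    for _ in range(100):
        W, b = sgd_step(x, labels, W, b)
    print("train accuracy:", accuracy(x, labels, W, b))
```

The gradient relies on the standard identity that the derivative of mean softmax cross-entropy with respect to the logits is `(softmax(logits) - labels) / batch_size`. The learning rate of 0.5 mirrors the original script, but whether it is a good step size depends on the scale of the input data.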