mnist_model.py source code

Language: Python

Project: dataset    Author: analysiscenter
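import tensorflow as tf  # requires the TensorFlow 1.x graph API (tf.placeholder, tf.layers, tf.train)


def static_nn():
    # Inputs: a batch of 28x28 single-channel images and their integer class labels
    input_images = tf.placeholder("uint8", [None, 28, 28, 1])
    input_labels = tf.placeholder("uint8", [None])

    # Flatten each image to 784 floats, then an MLP: 784 -> 512 (ReLU) -> 256 (ReLU) -> 10 logits
    input_vectors = tf.cast(tf.reshape(input_images, [-1, 28 * 28]), 'float')
    layer1 = tf.layers.dense(input_vectors, units=512, activation=tf.nn.relu)
    layer2 = tf.layers.dense(layer1, units=256, activation=tf.nn.relu)
    model_output = tf.layers.dense(layer2, units=10)  # raw logits, no activation
    encoded_labels = tf.one_hot(input_labels, depth=10)

    # Mean softmax cross-entropy; running `optimizer` performs one Adam update step
    cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=encoded_labels, logits=model_output))
    optimizer = tf.train.AdamOptimizer(learning_rate=0.001).minimize(cost)

    # Accuracy: fraction of samples whose argmax prediction matches the label
    prediction = tf.argmax(model_output, 1)
    correct_prediction = tf.equal(prediction, tf.argmax(encoded_labels, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, 'float'))

    # Placeholders to feed, followed by the ops to fetch
    return [[input_images, input_labels], [optimizer, cost, accuracy]]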
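To make explicit what the graph in static_nn computes, here is a minimal NumPy sketch of the same forward pass, loss and accuracy: flatten to 784 features, two ReLU dense layers (512 and 256 units), a 10-unit logit layer, mean softmax cross-entropy against one-hot labels, and argmax accuracy. It is an illustration only, not code from the analysiscenter dataset project. It performs no training, the helper names (`dense`, `init_params`, `forward`, `cost_and_accuracy`) are hypothetical, and the small-normal weight initialization is an assumption that does not reproduce TensorFlow's default initializer.

```python
import numpy as np

# Illustrative NumPy mirror of static_nn's graph (hypothetical helpers, no training).


def dense(x, w, b, relu=True):
    """Fully connected layer x @ w + b, optionally followed by ReLU."""
    y = x @ w + b
    return np.maximum(y, 0.0) if relu else y


def init_params(rng, sizes=(28 * 28, 512, 256, 10)):
    """Assumed init for illustration only: small normal weights, zero biases."""
    return [
        (rng.normal(0.0, 0.01, size=(m, n)).astype(np.float32), np.zeros(n, dtype=np.float32))
        for m, n in zip(sizes[:-1], sizes[1:])
    ]


def forward(images, params):
    """Flatten (N, 28, 28, 1) uint8 images to float (N, 784), then 784 -> 512 -> 256 -> 10."""
    x = images.reshape(-1, 28 * 28).astype(np.float32)
    h1 = dense(x, *params[0])
    h2 = dense(h1, *params[1])
    return dense(h2, *params[2], relu=False)  # raw logits: no activation on the last layer


def cost_and_accuracy(logits, labels, depth=10):
    """Mean softmax cross-entropy vs. one-hot labels, and argmax accuracy."""
    one_hot = np.eye(depth, dtype=np.float32)[labels]
    shifted = logits - logits.max(axis=1, keepdims=True)  # subtract row max to keep exp() finite
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    cost = float(-(one_hot * log_probs).sum(axis=1).mean())
    accuracy = float((logits.argmax(axis=1) == one_hot.argmax(axis=1)).mean())
    return cost, accuracy


rng = np.random.default_rng(0)
images = rng.integers(0, 256, size=(4, 28, 28, 1)).astype(np.uint8)
labels = np.array([3, 1, 4, 1], dtype=np.uint8)
logits = forward(images, init_params(rng))
print(logits.shape)  # (4, 10)
cost, accuracy = cost_and_accuracy(logits, labels)
print(cost, accuracy)
```

Subtracting the row maximum leaves the softmax unchanged but prevents overflow. Computing the loss from raw logits is also why the original leaves the last dense layer without an activation: tf.nn.softmax_cross_entropy_with_logits applies the softmax internally and expects unscaled logits.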