train.py source code

python

Project: comprehend    Author: Fenugreek
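import tensorflow as tf  # TF1.x graph-mode API: .eval() needs a default session


def get_label_costs(coder, dataset, labels, batch_size=100):
    """
    Return the average cross-entropy loss and classification error rate of
    the coder object, with its current weights, on dataset.

    Must be called with a default session active (e.g. inside
    `with sess.as_default():`). Only the first
    (len(dataset) // batch_size) * batch_size samples are scored; a trailing
    partial batch is ignored.
    """

    n_batches = dataset.shape[0] // batch_size
    if n_batches == 0:
        raise ValueError("dataset has fewer rows than batch_size")
    error = 0.
    cost = 0.
    for index in range(n_batches):
        batch = dataset[index * batch_size : (index+1) * batch_size]
        labels_batch = labels[index * batch_size : (index+1) * batch_size]
        # get_hidden_values is assumed to return unnormalized logits
        predicted = coder.get_hidden_values(batch)

        loss = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=predicted,
                                                              labels=labels_batch)
        cost += tf.reduce_mean(loss).eval()

        # tf.argmax returns int64; cast the labels so tf.not_equal sees matching dtypes
        bad_prediction = tf.not_equal(tf.argmax(predicted, 1),
                                      tf.cast(labels_batch, tf.int64))
        error += tf.reduce_mean(tf.cast(bad_prediction, tf.float32)).eval()

    return (cost / n_batches, error / n_batches)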
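To make the two metrics concrete without a TensorFlow session, here is a minimal NumPy sketch of what get_label_costs computes per batch. It assumes labels are integer class indices and that the coder's output is unnormalized logits (the input tf.nn.sparse_softmax_cross_entropy_with_logits expects). The names label_costs_numpy and logits_fn are hypothetical stand-ins, not part of the comprehend project.

```python
import numpy as np


def label_costs_numpy(logits_fn, dataset, labels, batch_size=100):
    """Average sparse softmax cross-entropy and error rate, batch by batch.

    logits_fn is a hypothetical stand-in for coder.get_hidden_values: any
    callable mapping a (batch_size, n_features) array to
    (batch_size, n_classes) logits.
    """
    n_batches = dataset.shape[0] // batch_size  # trailing partial batch is dropped
    if n_batches == 0:
        raise ValueError("dataset has fewer rows than batch_size")
    cost = 0.0
    error = 0.0
    for index in range(n_batches):
        rows = slice(index * batch_size, (index + 1) * batch_size)
        logits = logits_fn(dataset[rows])
        y = labels[rows]
        # Numerically stable log-softmax: subtract the row max before exp
        shifted = logits - logits.max(axis=1, keepdims=True)
        log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
        # Sparse cross-entropy: negative log-probability of the true class
        cost += (-log_probs[np.arange(len(y)), y]).mean()
        # Error rate: fraction of rows whose argmax differs from the label
        error += (logits.argmax(axis=1) != y).mean()
    return cost / n_batches, error / n_batches


def identity(x):
    return x  # treat each row as already-computed logits


data = np.array([[2.0, 0.0], [0.0, 2.0], [1.0, 1.0], [3.0, 0.0]])
targets = np.array([0, 1, 0, 1])
cost, err = label_costs_numpy(identity, data, targets, batch_size=2)
# err → 0.25 (only the last row is misclassified), cost ≈ 0.999
```

Because every batch has the same size, averaging the per-batch means gives the same result as a single mean over the first n_batches * batch_size samples.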