train.py source code

python

Project: gconv_experiments  Author: tscohen
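import numpy as np
from chainer import cuda, Variable
# ProgressBar is imported elsewhere in the original project; its source is not shown here.


def validate(test_data, test_labels, model, batchsize, silent, gpu):
    """Evaluate `model` on the test set in mini-batches.

    Returns the *summed* loss and accuracy over all N_test samples;
    divide each by N_test to get the per-sample means.
    """
    N_test = test_data.shape[0]
    pbar = ProgressBar(0, N_test)
    sum_accuracy = 0
    sum_loss = 0

    for i in range(0, N_test, batchsize):
        # Cast on both CPU and GPU paths; the last batch may be smaller than batchsize.
        x_batch = test_data[i:i + batchsize].astype(np.float32)
        y_batch = test_labels[i:i + batchsize].astype(np.int32)

        if gpu >= 0:
            # Copy the batch to the current GPU device.
            x_batch = cuda.to_gpu(x_batch)
            y_batch = cuda.to_gpu(y_batch)

        x = Variable(x_batch)
        t = Variable(y_batch)
        # train=False is this model's own flag for running in evaluation mode.
        loss, acc = model(x, t, train=False)

        # loss and acc are batch means; weight them by batch size before summing.
        sum_loss += float(cuda.to_cpu(loss.data)) * y_batch.size
        sum_accuracy += float(cuda.to_cpu(acc.data)) * y_batch.size
        if not silent:
            pbar.update(i + y_batch.size)

    return sum_loss, sum_accuracy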
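The reason the loop multiplies each batch's loss and accuracy by `y_batch.size` is that the model reports per-batch *means*, and the final batch is smaller whenever `N_test` is not a multiple of `batchsize`. Weighting by batch size makes `sum / N_test` the exact per-sample mean; `validate` itself returns the raw sums, so presumably its caller does that division (the caller is not shown). The sketch below illustrates the idea with NumPy only, no Chainer. The names `evaluate_in_batches` and `toy_predict` and the toy data are hypothetical, chosen just for illustration.

```python
import numpy as np


def evaluate_in_batches(predict_fn, data, labels, batchsize):
    """Hypothetical helper mirroring the loop in validate().

    predict_fn(x, t) must return the *mean* loss and accuracy over the
    batch. Each mean is multiplied by the batch size before summing, so
    dividing by N at the end gives the exact per-sample dataset mean even
    when the last batch is smaller than batchsize.
    """
    n = data.shape[0]
    sum_loss = 0.0
    sum_acc = 0.0
    for i in range(0, n, batchsize):
        x = data[i:i + batchsize]
        t = labels[i:i + batchsize]
        loss, acc = predict_fn(x, t)
        sum_loss += loss * t.size  # undo the per-batch averaging
        sum_acc += acc * t.size
    return sum_loss / n, sum_acc / n


def toy_predict(x, t):
    """Hypothetical stand-in for the model: argmax classifier with 0/1 loss."""
    pred = x.argmax(axis=1)
    acc = float((pred == t).mean())
    return 1.0 - acc, acc


# 5 samples, batchsize 2 -> batches of size 2, 2, 1.
scores = np.eye(2)[[0, 0, 1, 1, 0]]      # predicted classes: 0, 0, 1, 1, 0
labels_demo = np.array([0, 1, 1, 1, 1])  # 3 of 5 predictions are correct

mean_loss, mean_acc = evaluate_in_batches(toy_predict, scores, labels_demo, 2)
print(mean_loss, mean_acc)  # 0.4 0.6
```

Here the per-batch accuracies are 0.5, 1.0 and 0.0. Simply averaging those three numbers would give 0.5, overweighting the one-sample last batch, whereas the size-weighted version recovers the true 3/5 = 0.6.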