train_eval.py source code

python

Project: deep_metric_learning    Author: ronekko
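import copy

import numpy as np
from chainer import Variable, cuda
from tqdm import tqdm


def iterate_forward(model, epoch_iterator, normalize=False):
    """Run `model` over one epoch; return (embeddings, labels) as NumPy arrays."""
    xp = model.xp  # numpy on CPU, cupy on GPU
    y_batches = []
    c_batches = []
    # Iterate over a shallow copy so the original iterator's position is left unchanged.
    for batch in tqdm(copy.copy(epoch_iterator)):
        x_batch_data, c_batch_data = batch
        x_batch = Variable(xp.asarray(x_batch_data))
        y_batch = model(x_batch)
        if normalize:
            # L2-normalize each embedding (row) to unit length.
            y_batch_data = y_batch.data / xp.linalg.norm(
                y_batch.data, axis=1, keepdims=True)
        else:
            y_batch_data = y_batch.data
        y_batches.append(y_batch_data)
        y_batch = None  # drop the reference to the computational graph
        c_batches.append(c_batch_data)
    # Move the stacked embeddings to the host; labels are expected to be NumPy arrays.
    y_data = cuda.to_cpu(xp.concatenate(y_batches))
    c_data = np.concatenate(c_batches)
    return y_data, c_data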
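To make the data flow easier to follow without Chainer or a GPU, here is a minimal NumPy-only sketch of the same loop: embed each (x, c) batch, optionally L2-normalize every row, then concatenate the results. The function name `iterate_forward_numpy`, the toy weight matrix `W`, the stand-in `model`, and the batches are hypothetical illustrations, not part of the original project; the sketch also drops the progress bar and CPU/GPU transfer.

```python
import numpy as np


def iterate_forward_numpy(model, batches, normalize=False):
    """Hypothetical framework-free sketch of iterate_forward.

    `model` is any callable mapping an (n, d_in) array to an (n, d_out) array;
    `batches` is any iterable of (x, c) pairs. Both stand in for the Chainer
    model and epoch iterator used in the original code.
    """
    y_batches, c_batches = [], []
    for x, c in batches:
        y = model(x)
        if normalize:
            # Divide each row by its Euclidean norm so every embedding has length 1.
            y = y / np.linalg.norm(y, axis=1, keepdims=True)
        y_batches.append(y)
        c_batches.append(c)
    # Stack per-batch results into one (N, d_out) array and one (N,) label array.
    return np.concatenate(y_batches), np.concatenate(c_batches)


# Hypothetical linear "model": maps 3-d inputs to 2-d embeddings.
W = np.array([[1.0, 0.0], [0.0, 2.0], [1.0, 1.0]])


def model(x):
    return x @ W


# Two toy batches of (inputs, class labels).
batches = [
    (np.array([[3.0, 0.0, 0.0], [0.0, 2.0, 0.0]]), np.array([0, 1])),
    (np.array([[0.0, 0.0, 1.0]]), np.array([1])),
]

y, c = iterate_forward_numpy(model, batches, normalize=True)
print(y.shape, c)  # (3, 2) [0 1 1]
print(np.linalg.norm(y, axis=1))  # each row should have unit length
```

With `normalize=True`, each row of `y` has unit L2 norm, so the dot product of two rows equals their cosine similarity, which is a common reason to normalize embeddings before retrieval-style evaluation.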


# memory friendly average accuracy for test data