train.py source code

python

Project: DREAM    Author: LaceyChen17
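from math import ceil
from time import time

import torch


def evaluate_dream():
    """Evaluate DREAM on the test set and return the mean BPR loss per batch.

    Uses module-level names defined elsewhere in train.py: dr_model,
    dr_config, test_ub, batchify, bpr_loss, repackage_hidden, and the
    current training `epoch` (used only for logging).
    """
    dr_model.eval()  # switch off dropout etc. for evaluation
    dr_hidden = dr_model.init_hidden(dr_config.batch_size)

    total_loss = 0.0
    start_time = time()
    num_batches = ceil(len(test_ub) / dr_config.batch_size)
    with torch.no_grad():  # no gradients are needed when evaluating
        for baskets, lens, _ in batchify(test_ub, dr_config.batch_size):
            dynamic_user, _ = dr_model(baskets, lens, dr_hidden)
            loss = bpr_loss(baskets, dynamic_user, dr_model.encode.weight, dr_config)
            dr_hidden = repackage_hidden(dr_hidden)  # detach hidden state from the graph
            # .item() replaces the old `total_loss[0]` indexing, which fails on 0-dim tensors in PyTorch >= 0.4
            total_loss += loss.item()

    # Logging: average time per batch (ms) and average loss per batch
    elapsed = (time() - start_time) * 1000 / num_batches
    total_loss = total_loss / num_batches
    print('[Evaluation]| Epochs {:3d} | Elapsed {:02.2f} | Loss {:05.2f} |'.format(epoch, elapsed, total_loss))
    return total_loss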
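bpr_loss is defined elsewhere in the DREAM project and its body is not shown here. As a rough illustration of the Bayesian Personalized Ranking objective it is named after, the following is a minimal pure-Python sketch. It assumes (as the call passing dr_model.encode.weight suggests, but does not confirm) that an item's score is the dot product between the user's dynamic representation and that item's embedding, and that each positive item is paired with one sampled negative item. The names bpr_loss_sketch and dot, and the triple-based signature, are hypothetical and do not reproduce the project's actual function.

```python
import math
from typing import List, Sequence


def dot(u: Sequence[float], v: Sequence[float]) -> float:
    """Inner product of two equal-length vectors."""
    return sum(a * b for a, b in zip(u, v))


def bpr_loss_sketch(
    user_vecs: List[Sequence[float]],
    pos_items: List[int],
    neg_items: List[int],
    item_emb: List[Sequence[float]],
) -> float:
    """Hypothetical helper: mean BPR loss over (user, positive, negative) triples.

    Scores are assumed to be dot products between a user vector and item
    embeddings; each triple contributes -log(sigmoid(pos_score - neg_score)).
    """
    total = 0.0
    for u, i, j in zip(user_vecs, pos_items, neg_items):
        diff = dot(u, item_emb[i]) - dot(u, item_emb[j])
        # -log(sigmoid(diff)) == softplus(-diff) == log(1 + exp(-diff)),
        # computed in a form that cannot overflow for large |diff|
        x = -diff
        total += max(x, 0.0) + math.log1p(math.exp(-abs(x)))
    return total / len(pos_items)
```

For example, with item embeddings [[1.0, 0.0], [0.0, 1.0]] and a single user vector [2.0, 0.0] that prefers item 0, ranking item 0 above item 1 gives a small loss (about 0.127), while the reversed pairing gives a larger one. Minimizing this quantity pushes positive items to score above negative ones.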