train.py 文件源码

python
阅读 28 收藏 0 点赞 0 评论 0

项目:supervised-embedding-model 作者: sld 项目源码 文件源码
def _train(train_tensor, batch_size, neg_size, model, optimizer, sess):
    avg_loss = 0
    for batch in batch_iter(train_tensor, batch_size, True):
        for neg_batch in neg_sampling_iter(train_tensor, batch_size, neg_size):
            loss = sess.run(
                [model.loss, optimizer],
                feed_dict={model.context_batch: batch[:, 0, :],
                           model.response_batch: batch[:, 1, :],
                           model.neg_response_batch: neg_batch[:, 1, :]}
            )
            avg_loss += loss[0]
    avg_loss = avg_loss / (train_tensor.shape[0]*neg_size)
    return avg_loss
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号