tools.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:structured-output-ae 作者: sbelharbi 项目源码 文件源码
def get_net_cost(model, cost_type, eye=True):
    """Build the training cost of the network.

    The data term compares ``model.output_dropout`` with ``model.trg`` and
    is averaged over axis 1, giving one value per row of ``model.trg``.
    ``model.l1`` and ``model.l2`` are then added when they are not equal
    to ``0.``.

    Args:
        model: object exposing ``output_dropout``, ``trg``, ``l1`` and
            ``l2``. These are presumably Theano expressions (the cost is
            built with ``theano.tensor``) -- confirm against the model
            class.
        cost_type: ``CostType.MeanSquared`` or ``CostType.CrossEntropy``.
        eye: only affects ``CostType.MeanSquared``. When True, the per-row
            squared error is divided by ``d_eyes``, a term built from
            columns 37 and 46 of ``model.trg``. The cross-entropy cost is
            never divided by it.

    Returns:
        The per-row cost with the ``l1`` / ``l2`` terms added (if those are
        scalars, they are broadcast onto every row).

    Raises:
        ValueError: if ``cost_type`` is not one of the two supported values.
    """
    if cost_type == CostType.MeanSquared:
        cost = T.mean(T.sqr(model.output_dropout - model.trg), axis=1)
        if eye:
            # NOTE(review): both summands are the same column difference,
            # so this equals 2 * (trg[:, 37] - trg[:, 46])**2 rather than a
            # squared distance between two 2-D points. One summand was
            # presumably meant to use the other coordinate of the two eye
            # points; confirm the column layout of ``trg`` before changing
            # it, because doing so changes the training objective.
            d_eyes = (
                (model.trg[:, 37] - model.trg[:, 46])**2 +
                (model.trg[:, 37] - model.trg[:, 46])**2).T
            cost = cost / d_eyes
    elif cost_type == CostType.CrossEntropy:
        # No ``eye`` normalisation is applied to the cross-entropy cost.
        cost = T.mean(
            T.nnet.binary_crossentropy(model.output_dropout, model.trg),
            axis=1)
    else:
        raise ValueError("cost type unknow.")

    # Presumably L1/L2 regularisation penalties; they are only added when
    # they are not equal to 0.
    if model.l1 != 0.:
        cost = cost + model.l1
    if model.l2 != 0.:
        cost = cost + model.l2
    return cost
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号