train_celeba_classifier.py source code


Project: discgen  Author: vdumoulin
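```python
# Imports assumed from the surrounding script (Theano + Blocks);
# create_model_bricks() is defined elsewhere in train_celeba_classifier.py.
from theano import tensor

from blocks.bricks.cost import BinaryCrossEntropy
from blocks.filter import VariableFilter
from blocks.graph import (ComputationGraph, apply_batch_normalization,
                          apply_dropout)
from blocks.roles import OUTPUT


def create_training_computation_graphs():
    # Images as (batch, channels, height, width); binary attribute
    # targets as (batch, n_attributes).
    x = tensor.tensor4('features')
    y = tensor.imatrix('targets')

    # Convolutional feature extractor followed by an MLP classifier.
    convnet, mlp = create_model_bricks()
    y_hat = mlp.apply(convnet.apply(x).flatten(ndim=2))
    cost = BinaryCrossEntropy().apply(y, y_hat)
    # Fraction of attribute predictions that match the targets at a 0.5 threshold.
    accuracy = 1 - tensor.neq(y > 0.5, y_hat > 0.5).mean()
    cg = ComputationGraph([cost, accuracy])

    # Create a graph that uses batch statistics for batch normalization
    # and applies dropout to selected variables.
    bn_cg = apply_batch_normalization(cg)
    # Outputs of convnet layers 5, 11 and 17 and of the brick behind the
    # MLP's second application method receive dropout.
    bricks_to_drop = ([convnet.layers[i] for i in (5, 11, 17)] +
                      [mlp.application_methods[1].brick])
    variables_to_drop = VariableFilter(
        roles=[OUTPUT], bricks=bricks_to_drop)(bn_cg.variables)
    bn_dropout_cg = apply_dropout(bn_cg, variables_to_drop, 0.5)

    # cg: inference graph (population statistics, no dropout);
    # bn_dropout_cg: training graph.
    return cg, bn_dropout_cg
```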
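The cost in the function above comes from Blocks' `BinaryCrossEntropy` brick. Assuming it behaves like Blocks' other `CostMatrix` bricks, it computes the element-wise binary cross-entropy between `y` and `y_hat`, sums over the attribute axis, and averages over the batch. The NumPy sketch below reproduces that reduction under this assumption. `binary_cross_entropy` and its `eps` clipping are illustrative additions, not part of Blocks.

```python
import numpy as np


def binary_cross_entropy(y, y_hat, eps=1e-7):
    """Element-wise BCE, summed over attributes and averaged over the batch.

    Assumption: mirrors the CostMatrix-style reduction described above.
    """
    y = np.asarray(y, dtype=float)
    # Clip predictions away from 0 and 1 so the logarithms stay finite.
    y_hat = np.clip(np.asarray(y_hat, dtype=float), eps, 1 - eps)
    per_entry = -(y * np.log(y_hat) + (1 - y) * np.log(1 - y_hat))
    return per_entry.sum(axis=1).mean()


# Two images, three binary attributes each.
y = np.array([[1, 0, 1], [0, 0, 1]])
y_hat = np.array([[0.9, 0.2, 0.7], [0.1, 0.4, 0.8]])
cost = binary_cross_entropy(y, y_hat)
```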
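The accuracy expression thresholds both targets and predictions at 0.5. It then averages the mismatches over the whole `(batch, attributes)` matrix, so it measures per-attribute accuracy rather than the fraction of images with every attribute correct. A NumPy equivalent follows; the helper name `multilabel_accuracy` is hypothetical.

```python
import numpy as np


def multilabel_accuracy(y, y_hat, threshold=0.5):
    """1 - mean mismatch over every (example, attribute) entry."""
    mismatches = (np.asarray(y) > threshold) != (np.asarray(y_hat) > threshold)
    return 1.0 - mismatches.mean()


y = np.array([[1, 0, 1], [0, 0, 1]])
y_hat = np.array([[0.9, 0.6, 0.4], [0.1, 0.2, 0.8]])
acc = multilabel_accuracy(y, y_hat)  # 2 of 6 entries wrong
```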
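`apply_batch_normalization` rewrites the graph so that batch-normalized layers normalize with the statistics of the current minibatch. The untouched `cg` keeps using stored population statistics. The sketch below shows the two modes for a dense `(batch, features)` input with the standard batch-normalization formula. The function names and the `eps` value are illustrative assumptions; a convolutional layer would also average over the spatial axes, and how the population statistics are estimated is outside this snippet.

```python
import numpy as np


def batch_norm(x, gamma, beta, mean, var, eps=1e-5):
    """Standard batch normalization with the given statistics."""
    return gamma * (x - mean) / np.sqrt(var + eps) + beta


def batch_norm_train(x, gamma, beta, eps=1e-5):
    """Training mode: normalize with the current minibatch's statistics."""
    return batch_norm(x, gamma, beta, x.mean(axis=0), x.var(axis=0), eps)


def batch_norm_inference(x, gamma, beta, pop_mean, pop_var, eps=1e-5):
    """Inference mode: normalize with stored population statistics."""
    return batch_norm(x, gamma, beta, pop_mean, pop_var, eps)


rng = np.random.default_rng(0)
x = rng.normal(3.0, 2.0, size=(64, 5))
out = batch_norm_train(x, np.ones(5), np.zeros(5))
```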
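`apply_dropout(bn_cg, variables_to_drop, 0.5)` zeroes each selected output with probability 0.5 during training. Blocks documents that, unless `custom_divisor` is given, the surviving values are divided by `1 - drop_prob`, so no rescaling is needed at test time. A NumPy sketch of that inverted-dropout scheme follows; the helper name `inverted_dropout` is hypothetical.

```python
import numpy as np


def inverted_dropout(x, drop_prob, rng):
    """Zero entries with probability drop_prob and rescale the survivors."""
    keep_prob = 1.0 - drop_prob
    # Bernoulli(keep_prob) mask with the same shape as x.
    mask = rng.binomial(1, keep_prob, size=x.shape)
    return x * mask / keep_prob


rng = np.random.default_rng(0)
x = np.ones(1000)
dropped = inverted_dropout(x, 0.5, rng)
```

With `drop_prob=0.5`, every kept activation is doubled, which keeps the expected value of each unit unchanged.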