function_pkl_test.py source code

python

Project: Synkhronos    Author: astooke
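import theano
import theano.tensor as T
import lasagne.layers as L
from lasagne import updates


def build_train_func(rank=0, **kwargs):
    print("rank: {} Building model".format(rank))
    # build_resnet() returns a dict of Lasagne layers; it is defined elsewhere in the project (not shown here)
    resnet = build_resnet()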

    print("Building training function")
    x = T.ftensor4('x')
    y = T.imatrix('y')

    prob = L.get_output(resnet['prob'], x, deterministic=False)
    loss = T.nnet.categorical_crossentropy(prob, y.flatten()).mean()
    params = L.get_all_params(resnet.values(), trainable=True)

    sgd_updates = updates.sgd(loss, params, learning_rate=1e-4)

    # compile a training function: each call returns the batch loss and applies the SGD updates
    f_train = theano.function(inputs=[x, y],
                              outputs=loss,  # mean cross-entropy over the batch
                              updates=sgd_updates)

    return f_train, "original"
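To make concrete what each call to `f_train` does, here is a minimal NumPy sketch of one SGD step on mean categorical cross-entropy. It is not the Synkhronos code: it swaps the ResNet for a hypothetical linear softmax classifier (`W`, `b`) so it runs without Theano or Lasagne, and `softmax` and `sgd_train_step` are illustrative names. Like the compiled Theano function, it flattens `(N, 1)` integer labels, returns the loss computed with the pre-update parameters, and then applies `param <- param - learning_rate * grad`.

```python
import numpy as np


def softmax(z):
    # Subtract the row max for numerical stability before exponentiating.
    z = z - z.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)


def sgd_train_step(W, b, x, y, learning_rate=1e-4):
    """One SGD step on a hypothetical linear softmax classifier (stand-in for the ResNet).

    x: (N, D) float inputs; y: (N, 1) int labels, mirroring T.imatrix('y').
    Returns the mean cross-entropy computed *before* the update, which is what
    a Theano function compiled with `updates=` returns as its output.
    """
    labels = y.flatten()                       # (N, 1) -> (N,), like y.flatten()
    n = x.shape[0]
    prob = softmax(x @ W + b)                  # (N, C) class probabilities
    loss = -np.log(prob[np.arange(n), labels]).mean()

    # Gradient of the mean cross-entropy w.r.t. the logits: (prob - one_hot) / N
    dlogits = prob.copy()
    dlogits[np.arange(n), labels] -= 1.0
    dlogits /= n
    grad_W = x.T @ dlogits
    grad_b = dlogits.sum(axis=0)

    # Plain SGD, applied in place (analogous to updating shared variables)
    W -= learning_rate * grad_W
    b -= learning_rate * grad_b
    return loss
```

Because `W` and `b` are modified in place, repeated calls keep training the same parameters, much as the shared variables behind the Lasagne layers persist across calls to `f_train`.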