fxnn.py 文件源码

python
阅读 43 收藏 0 点赞 0 评论 0

项目:fxnn 作者: khaotik 项目源码 文件源码
def build_model(model_):
    """Compile the Theano prediction, recording and training functions.

    Builds the symbolic graph around ``model_`` and publishes the results
    through module-level globals:

    * ``g_ozer``     -- optimizer instance selected by ``OZER`` (``'simple'``
      or ``'adam'``), with ``lr`` set to ``LEARN_RATE`` and compiled against
      the loss.
    * ``fn_predict`` -- maps ``(x, y)`` to a dict with keys ``'pred'``,
      ``'accr'`` and ``'loss'``.
    * ``fn_record``  -- maps ``(x, y)`` to a dict with keys ``'x'``, ``'y'``,
      ``'pred'`` plus every entry of ``g_mdl.params_di``.

    Args:
        model_: Callable invoked as ``model_(s_x, s_pdpo)`` with a symbolic
            4-D tensor and a symbolic scalar; its result is used as the
            per-class output. NOTE(review): the scalar is presumably a
            dropout probability and the output a (batch, n_classes) matrix
            of probabilities -- confirm against the model definition.

    Raises:
        KeyError: If ``OZER`` is neither ``'simple'`` nor ``'adam'``.
    """
    global fn_predict, fn_record
    global g_ozer
    # g_mdl is only read below, so it needs no ``global`` declaration.

    optimizers = dict(simple=VanillaSGD, adam=AdamSGD)
    g_ozer = optimizers[OZER]()
    g_ozer.lr = LEARN_RATE

    s_x = T.tensor4('x')
    s_y = T.ivector('y')
    s_pdpo = T.scalar()
    s_out = model_(s_x, s_pdpo)

    n_labels = len(g_dataset.label_map)
    s_y_onehot = T.extra_ops.to_one_hot(s_y, n_labels)
    # The 1e-3 offset keeps log() finite when an output is exactly zero.
    # The mean runs over every element of the product, not per sample.
    s_loss = T.mean(-s_y_onehot * T.log(s_out + 1e-3))
    s_hit = T.eq(T.argmax(s_out, axis=1), T.argmax(s_y_onehot, axis=1))
    s_accr = T.mean(T.switch(s_hit, 1, 0))

    # Prediction and recording substitute 0 for the scalar model input;
    # training substitutes TRAIN_PDPO instead.
    no_dropout = [(s_pdpo, T.constant(0., dtype=th.config.floatX))]
    train_dropout = [
        (s_pdpo, T.constant(TRAIN_PDPO, dtype=th.config.floatX))]

    fn_predict = th.function(
        [s_x, s_y],
        {'pred': s_out, 'accr': s_accr, 'loss': s_loss},
        givens=no_dropout, profile=PROFILE)

    rec_fetches = {'x': s_x, 'y': s_y, 'pred': s_out}
    # Parameter entries are merged last, so a parameter named 'x', 'y' or
    # 'pred' would replace the corresponding fetch above.
    rec_fetches.update(g_mdl.params_di)
    fn_record = th.function(
        [s_x, s_y], rec_fetches, givens=no_dropout, profile=PROFILE)

    # Materialize the parameters as a real list: on Python 3 ``dict.values()``
    # is a view, which Theano does not treat as a sequence of variables.
    params = list(g_mdl.params_di.values())
    g_ozer.compile(
        [s_x, s_y],
        s_loss,
        params,
        fetches_={'pred': s_out, 'loss': s_loss, 'accr': s_accr},
        givens_=train_dropout,
        profile_=PROFILE)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号