def build_model(model_):
    """Build the symbolic graph around ``model_`` and compile its functions.

    The results are published through module-level globals rather than
    returned:

    * ``g_ozer``     -- optimizer instance selected by ``OZER``, with its
      ``lr`` attribute set to ``LEARN_RATE`` and then compiled on the loss.
    * ``fn_predict`` -- compiled function of ``(x, y)`` yielding the
      ``'pred'``, ``'accr'`` and ``'loss'`` outputs.
    * ``fn_record``  -- compiled function of ``(x, y)`` yielding ``'x'``,
      ``'y'``, ``'pred'`` plus every entry of ``g_mdl.params_di``.

    ``model_`` is called as ``model_(x, p)`` with a 4-D symbolic tensor and
    a symbolic scalar; its result is used as the prediction.
    NOTE(review): the prediction is presumably a (sample, class) matrix of
    probabilities, and the scalar presumably a dropout probability (see
    ``no_dropout`` / ``TRAIN_PDPO``) -- confirm against the model code.
    NOTE(review): ``g_mdl`` is only read here, so it must already be set by
    the caller (or by ``model_`` itself) -- confirm.
    """
    global fn_predict, fn_record
    global g_ozer, g_mdl

    # Choose the optimizer class by name, instantiate it, set its step size.
    optimizer_classes = {'simple': VanillaSGD, 'adam': AdamSGD}
    g_ozer = optimizer_classes[OZER]()
    g_ozer.lr = LEARN_RATE

    # Symbolic inputs: 4-D data tensor, integer label vector, and the scalar
    # that is substituted below via `givens`.
    sym_input = T.tensor4('x')
    sym_labels = T.ivector('y')
    sym_pdpo = T.scalar()

    sym_pred = model_(sym_input, sym_pdpo)

    # Loss: negated one-hot * log(prediction), averaged over every element.
    # The 1e-3 offset keeps the log finite when a prediction is exactly 0.
    n_classes = len(g_dataset.label_map)
    sym_onehot = T.extra_ops.to_one_hot(sym_labels, n_classes)
    sym_log_pred = T.log(sym_pred + 1e-3)
    sym_loss = T.mean(-sym_onehot * sym_log_pred)

    # Accuracy: fraction of rows whose arg-max along axis 1 matches the label.
    guessed = T.argmax(sym_pred, axis=1)
    expected = T.argmax(sym_onehot, axis=1)
    sym_accr = T.mean(T.switch(T.eq(guessed, expected), 1, 0))

    float_x = th.config.floatX
    inputs = [sym_input, sym_labels]

    # Evaluation-time functions replace the scalar with a constant 0.
    no_dropout = [(sym_pdpo, T.constant(0., dtype=float_x))]

    fn_predict = th.function(
        inputs,
        {'pred': sym_pred, 'accr': sym_accr, 'loss': sym_loss},
        givens=no_dropout,
        profile=PROFILE)

    # The recording function echoes its inputs and the prediction, and also
    # exposes whatever `g_mdl.params_di` holds.
    rec_fetches = {
        'x': sym_input,
        'y': sym_labels,
        'pred': sym_pred,
    }
    rec_fetches.update(g_mdl.params_di)
    fn_record = th.function(
        inputs,
        rec_fetches,
        givens=no_dropout,
        profile=PROFILE)

    # Training step: the optimizer compiles against the loss and the model
    # parameters, with the scalar fixed to TRAIN_PDPO.
    train_givens = [(sym_pdpo, T.constant(TRAIN_PDPO, dtype=float_x))]
    g_ozer.compile(
        inputs,
        sym_loss,
        g_mdl.params_di.values(),
        fetches_={'pred': sym_pred, 'loss': sym_loss, 'accr': sym_accr},
        givens_=train_givens,
        profile_=PROFILE)
评论列表
文章目录