def get_dssm():
    """Build the symbolic graph for a DSSM-style pairwise ranking model.

    The graph has a sparse user tower (csr input times a row-sparse
    embedding matrix) and a shared three-layer MLP doc tower applied to
    both a positive and a negative document. Both towers are
    L2-normalized, so the row-wise dot product is a cosine similarity.

    Returns:
        mx.sym.Symbol: output symbol named 'mae' whose data is the
        pairwise log-loss log(1 + exp(sim_neg - sim_pos)).

    Note:
        Relies on module-level globals ``mx``, ``USR_NUM``, ``OUT_DIM``
        and ``num_hidden`` defined elsewhere in this file.
    """
    # Inputs: dense doc features for the positive/negative pair,
    # and a sparse (csr) user feature row.
    pos_doc = mx.sym.Variable('doc_pos')
    neg_doc = mx.sym.Variable('doc_neg')
    usr_sparse = mx.sym.Variable("data_usr", stype='csr')
    # User embedding table; row_sparse so only touched rows get gradients.
    usr_emb_weight = mx.sym.Variable('usr_weight', stype='row_sparse',
                                     shape=(USR_NUM, OUT_DIM))

    # Doc-tower parameters, shared between the positive and negative branches.
    w1 = mx.sym.Variable('fc1_doc_weight')
    w2 = mx.sym.Variable('fc2_doc_weight')
    w3 = mx.sym.Variable('fc3_doc_weight')
    b1 = mx.sym.Variable('fc1_doc_bias')
    b2 = mx.sym.Variable('fc2_doc_bias')
    b3 = mx.sym.Variable('fc3_doc_bias')

    def _row_dot(a, b):
        """Row-wise dot product; equals cosine since inputs are L2-normalized."""
        return mx.sym.sum_axis(a * b, axis=1)

    def _doc_tower(x):
        """Three FC+ReLU layers followed by L2 normalization."""
        out = x
        for layer_name, w, b, width in (('fc1', w1, b1, num_hidden),
                                        ('fc2', w2, b2, num_hidden),
                                        ('fc3', w3, b3, OUT_DIM)):
            out = mx.sym.FullyConnected(data=out, num_hidden=width,
                                        name=layer_name, weight=w, bias=b)
            out = mx.sym.Activation(data=out, act_type='relu')
        return mx.sym.L2Normalization(data=out)

    # User tower: sparse lookup/projection, then L2-normalize.
    user_vec = mx.sym.L2Normalization(data=mx.sym.dot(usr_sparse, usr_emb_weight))

    # Shared-weight doc tower applied to both documents.
    sim_pos = _row_dot(user_vec, _doc_tower(pos_doc))
    sim_neg = _row_dot(user_vec, _doc_tower(neg_doc))

    # Pairwise log-loss log(1 + exp(sim_neg - sim_pos)); presumably trained
    # against an all-zero label via MAERegressionOutput so minimizing the MAE
    # drives the loss itself toward zero — confirm against the training code.
    loss = mx.sym.log1p(data=mx.sym.exp(data=(sim_neg - sim_pos)))
    return mx.sym.MAERegressionOutput(data=loss, name='mae')
# NOTE(review): removed trailing blog-page artifacts ("评论列表" = comment list,
# "文章目录" = article table of contents) — web-scrape residue, not code.