dssm.py source code


Project: recom-system   Author: tizot
from lasagne import init, layers, nonlinearities


def build_multi_dssm(input_var=None, num_samples=None, num_entries=6, num_ngrams=42**3, num_hid1=300, num_hid2=300, num_out=128):
    """Builds a DSSM structure in a Lasagne/Theano way.

    The built DSSM is the neural network that computes the projection of only one paper.
    The input ``input_var`` should have two dimensions: (``num_samples * num_entries``, ``num_ngrams``).
    The output is computed in batches: each paper (one input row) is projected independently, but all papers from the same
    sample in the dataset occupy consecutive rows (the cited paper, the citing paper and ``num_entries - 2`` irrelevant papers).

    Args:
        input_var (:class:`theano.tensor.TensorType` or None): symbolic input variable of the DSSM
        num_samples (int or None): the number of samples in the input batch; the input then has ``num_samples * num_entries`` rows (None for a variable batch size)
        num_entries (int): the number of compared papers in the DSSM structure
        num_ngrams (int): the size of the vocabulary
        num_hid1 (int): the number of units in the first hidden layer
        num_hid2 (int): the number of units in the second hidden layer
        num_out (int): the number of units in the output layer

    Returns:
        :class:`lasagne.layers.Layer`: the output layer of the DSSM
    """

    assert num_entries > 2, "each sample needs at least one irrelevant paper"

    # Initialise input layer
    if num_samples is None:
        num_rows = None
    else:
        num_rows = num_samples * num_entries

    l_in = layers.InputLayer(shape=(num_rows, num_ngrams), input_var=input_var)

    # Initialise the hidden and output layers of the DSSM
    l_hid1 = layers.DenseLayer(l_in, num_units=num_hid1, nonlinearity=nonlinearities.tanh, W=init.GlorotUniform())
    l_hid2 = layers.DenseLayer(l_hid1, num_units=num_hid2, nonlinearity=nonlinearities.tanh, W=init.GlorotUniform())
    l_out = layers.DenseLayer(l_hid2, num_units=num_out, nonlinearity=nonlinearities.tanh, W=init.GlorotUniform())

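    # L2-normalise each output row (one paper per row) so that dot products between rows are cosine similarities
    l_out = layers.ExpressionLayer(l_out, lambda X: X / X.norm(2, axis=1).dimshuffle(0, 'x'), output_shape='auto')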

    return l_out
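To make the input layout and the layer stack concrete, here is a minimal pure-Python sketch (no Theano or Lasagne required) of the same structure: three tanh dense layers with Glorot-uniform weights and zero biases (Lasagne's `DenseLayer` default), followed by a row-wise L2 normalisation. All helper names (`glorot_uniform`, `dense_tanh`, `l2_normalize_rows`, `dssm_forward`, `group_scores`) are hypothetical, the dimensions are toy values, and `group_scores` assumes, purely for illustration, that row 0 of each group of `num_entries` consecutive rows is the paper the other rows are compared against.

```python
import math
import random


def glorot_uniform(fan_in, fan_out, rng):
    """Weight matrix (fan_in x fan_out) drawn from U(-a, a), a = sqrt(6 / (fan_in + fan_out))."""
    a = math.sqrt(6.0 / (fan_in + fan_out))
    return [[rng.uniform(-a, a) for _ in range(fan_out)] for _ in range(fan_in)]


def dense_tanh(rows, W, b):
    """Apply tanh(x . W + b) to every row vector x."""
    return [
        [math.tanh(sum(x[i] * W[i][j] for i in range(len(x))) + b[j]) for j in range(len(b))]
        for x in rows
    ]


def l2_normalize_rows(rows, eps=1e-12):
    """Divide each row by its own Euclidean norm (eps guards against all-zero rows)."""
    result = []
    for r in rows:
        norm = math.sqrt(sum(v * v for v in r))
        result.append([v / max(norm, eps) for v in r])
    return result


def dssm_forward(rows, params):
    """Project every input row (one paper) through the tanh layers, then L2-normalise."""
    h = rows
    for W, b in params:
        h = dense_tanh(h, W, b)
    return l2_normalize_rows(h)


def group_scores(embeddings, num_entries):
    """Split rows into consecutive groups of num_entries papers; score rows 1.. against row 0.

    Hypothetical convention: row 0 of each group is the reference paper.
    Because rows are unit-norm, each dot product is a cosine similarity.
    """
    assert len(embeddings) % num_entries == 0
    scores = []
    for start in range(0, len(embeddings), num_entries):
        group = embeddings[start:start + num_entries]
        ref = group[0]
        scores.append([sum(a * c for a, c in zip(ref, other)) for other in group[1:]])
    return scores


# Toy usage: 2 samples x 4 papers each, a 10-entry n-gram vocabulary, small layers.
rng = random.Random(0)
num_samples, num_entries = 2, 4
num_ngrams, num_hid1, num_hid2, num_out = 10, 8, 8, 5
sizes = [num_ngrams, num_hid1, num_hid2, num_out]
params = [(glorot_uniform(m, n, rng), [0.0] * n) for m, n in zip(sizes, sizes[1:])]

X = []
for k in range(num_samples * num_entries):
    row = [float(rng.random() < 0.3) for _ in range(num_ngrams)]
    row[k % num_ngrams] = 1.0  # make sure every paper has at least one active n-gram
    X.append(row)

embeddings = dssm_forward(X, params)  # num_samples * num_entries rows of length num_out
scores = group_scores(embeddings, num_entries)  # num_samples lists of num_entries - 1 scores
```

Since every output row has unit length, each value in `scores` is a cosine similarity in [-1, 1]; normalising per row rather than over the whole batch keeps one paper's embedding independent of the other papers in the batch.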