representations.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:spotlight 作者: maciejkula 项目源码 文件源码
def __init__(self, num_items,
                 embedding_dim=32,
                 kernel_width=3,
                 dilation=1,
                 num_layers=1,
                 nonlinearity='tanh',
                 residual_connections=True,
                 sparse=False,
                 benchmark=True,
                 item_embedding_layer=None):
        """
        Set up the embeddings, bias terms and convolutional stack.

        Parameters
        ----------
        num_items:
            First argument given to the item embedding and item bias
            layers (presumably the number of distinct item ids).
        embedding_dim:
            Embedding size; also used as both the input and output
            channel count of every convolution.
        kernel_width:
            Passed through ``_to_iterable`` together with ``num_layers``;
            each resulting value is the first kernel dimension of one
            convolution. NOTE(review): presumably a scalar is repeated
            once per layer -- confirm against ``_to_iterable``.
        dilation:
            Handled like ``kernel_width``; each resulting value is the
            first dilation dimension of one convolution.
        num_layers:
            Forwarded to ``_to_iterable`` for both ``kernel_width`` and
            ``dilation``.
        nonlinearity:
            Either ``'tanh'`` or ``'relu'``.
        residual_connections:
            Stored on the instance unchanged.
        sparse:
            Forwarded to the embedding layers created here.
        benchmark:
            Assigned to ``cudnn.benchmark``. This is an attribute of the
            ``cudnn`` module, not of this instance.
        item_embedding_layer:
            When not ``None``, used as the item embedding layer instead
            of creating a new ``ScaledEmbedding``.

        Raises
        ------
        ValueError
            If ``nonlinearity`` is neither ``'tanh'`` nor ``'relu'``.
        """

        super(CNNNet, self).__init__()

        # Module-level switch: set before anything else is configured.
        cudnn.benchmark = benchmark

        self.embedding_dim = embedding_dim
        self.kernel_width = _to_iterable(kernel_width, num_layers)
        self.dilation = _to_iterable(dilation, num_layers)

        # Resolve the activation name to the callable it stands for.
        if nonlinearity == 'relu':
            activation = F.relu
        elif nonlinearity == 'tanh':
            activation = F.tanh
        else:
            raise ValueError('Nonlinearity must be one of (tanh, relu)')
        self.nonlinearity = activation
        self.residual_connections = residual_connections

        # Only build a fresh embedding layer when the caller gave none.
        if item_embedding_layer is None:
            item_embedding_layer = ScaledEmbedding(num_items,
                                                   embedding_dim,
                                                   padding_idx=PADDING_IDX,
                                                   sparse=sparse)
        self.item_embeddings = item_embedding_layer

        # The bias layer is always created here, even when an embedding
        # layer was supplied by the caller.
        self.item_biases = ZeroEmbedding(num_items, 1,
                                         padding_idx=PADDING_IDX,
                                         sparse=sparse)

        # One convolution per (kernel width, dilation) pair; the second
        # kernel/dilation dimension is fixed at 1.
        conv_stack = []
        for width, rate in zip(self.kernel_width, self.dilation):
            conv_stack.append(nn.Conv2d(embedding_dim,
                                        embedding_dim,
                                        kernel_size=(width, 1),
                                        dilation=(rate, 1)))
        self.cnn_layers = conv_stack

        # The layers live in a plain list, so each one is registered
        # explicitly under the name cnn_<index>.
        for index, conv in enumerate(conv_stack):
            self.add_module('cnn_{}'.format(index), conv)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号