dream.py source code

python

Project: DREAM  Author: LaceyChen17
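import torch

# pool_avg and pool_max are basket-pooling helpers defined elsewhere in the
# DREAM project; their definitions are not shown on this page.

class DreamModel(torch.nn.Module):
    def __init__(self, config):
        super(DreamModel, self).__init__()
        # Model configuration
        self.config = config
        # Layer definitions
        # Item embedding layer: maps product IDs to dense vectors. Index 0 is the
        # padding index (its row starts at zero and receives no gradient).
        self.encode = torch.nn.Embedding(config.num_product,
                                         config.embedding_dim,
                                         padding_idx=0)
        # Pooling of a basket's item embeddings into one basket vector ('avg' or 'max')
        self.pool = {'avg': pool_avg, 'max': pool_max}[config.basket_pool_type]
        # RNN type: LSTM or GRU is looked up by name in torch.nn; input and
        # hidden sizes both equal embedding_dim
        if config.rnn_type in ['LSTM', 'GRU']:
            self.rnn = getattr(torch.nn, config.rnn_type)(config.embedding_dim,
                                                          config.embedding_dim,
                                                          config.rnn_layer_num,
                                                          batch_first=True,
                                                          dropout=config.dropout)
        else:
            # Vanilla RNN: 'RNN_TANH' or 'RNN_RELU' selects the nonlinearity
            nonlinearity = {'RNN_TANH': 'tanh', 'RNN_RELU': 'relu'}[config.rnn_type]
            self.rnn = torch.nn.RNN(config.embedding_dim,
                                    config.embedding_dim,
                                    config.rnn_layer_num,
                                    nonlinearity=nonlinearity,
                                    batch_first=True,
                                    dropout=config.dropout)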
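The constructor only builds the layers. The following is a minimal sketch, not the project's forward pass, of how those layers could turn one user's basket history into RNN states. It assumes that `pool_avg` and `pool_max` reduce a basket's item embeddings over the item dimension; the versions below are hypothetical stand-ins, since the real helpers (and their signatures) are not shown here. The `config` values and the `encode_user` function are likewise made up for illustration.

```python
import torch
from types import SimpleNamespace

# Hypothetical stand-ins for the project's pool_avg / pool_max helpers;
# the real DREAM definitions may take different arguments.
def pool_avg(item_embeddings):
    # (num_items, embedding_dim) -> (embedding_dim,): mean over the basket's items
    return item_embeddings.mean(dim=0)

def pool_max(item_embeddings):
    # Element-wise maximum over the basket's items
    return item_embeddings.max(dim=0).values

# Hypothetical config carrying the fields that __init__ reads
config = SimpleNamespace(num_product=100, embedding_dim=8, basket_pool_type='avg',
                         rnn_type='GRU', rnn_layer_num=2, dropout=0.5)

# Same layer construction as in __init__
encode = torch.nn.Embedding(config.num_product, config.embedding_dim, padding_idx=0)
pool = {'avg': pool_avg, 'max': pool_max}[config.basket_pool_type]
rnn = getattr(torch.nn, config.rnn_type)(config.embedding_dim, config.embedding_dim,
                                          config.rnn_layer_num, batch_first=True,
                                          dropout=config.dropout)

def encode_user(baskets):
    """Hypothetical helper: one user's baskets (lists of product IDs) -> RNN outputs."""
    # Embed each basket's products, then pool them into a single basket vector
    basket_vectors = [pool(encode(torch.tensor(b, dtype=torch.long))) for b in baskets]
    # Stack into a batch of one sequence: (1, num_baskets, embedding_dim)
    seq = torch.stack(basket_vectors).unsqueeze(0)
    output, hidden = rnn(seq)
    return output, hidden

baskets = [[3, 7], [12], [5, 9, 42]]  # product IDs start at 1; 0 is padding
output, hidden = encode_user(baskets)
print(output.shape)  # torch.Size([1, 3, 8])
print(hidden.shape)  # torch.Size([2, 1, 8])
```

With `rnn_type='LSTM'` the second return value would be an `(h_n, c_n)` tuple rather than a single tensor. Because both the pooling function and the vanilla-RNN nonlinearity are chosen by dictionary lookup, an unsupported `basket_pool_type` or `rnn_type` fails early with a `KeyError` when the model is constructed.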