def __init__(self, config):
    """Build the DREAM model's layers from ``config``.

    Args:
        config: configuration object. This method reads ``num_product``,
            ``embedding_dim``, ``basket_pool_type``, ``rnn_type``,
            ``rnn_layer_num`` and ``dropout`` from it.

    Raises:
        KeyError: if ``config.basket_pool_type`` is not ``'avg'`` or
            ``'max'``, or if ``config.rnn_type`` is not one of ``'LSTM'``,
            ``'GRU'``, ``'RNN_TANH'``, ``'RNN_RELU'``.
    """
    super(DreamModel, self).__init__()
    # Model configuration
    self.config = config

    # Item embedding: one ``embedding_dim``-sized vector per product id.
    # padding_idx=0 reserves id 0 for padding: in torch.nn.Embedding that
    # row starts as zeros and is excluded from gradient updates.
    self.encode = torch.nn.Embedding(config.num_product,
                                     config.embedding_dim,
                                     padding_idx=0)
    # Basket pooling function, selected by name.
    self.pool = {'avg': pool_avg, 'max': pool_max}[config.basket_pool_type]

    # torch.nn recurrent modules apply dropout only between stacked layers.
    # With a single layer a non-zero value has no effect and PyTorch emits a
    # UserWarning, so pass 0 in that case.
    dropout = config.dropout if config.rnn_layer_num > 1 else 0.0
    rnn_kwargs = {'batch_first': True, 'dropout': dropout}

    if config.rnn_type in ['LSTM', 'GRU']:
        rnn_cls = getattr(torch.nn, config.rnn_type)
    else:
        # Plain RNN: the rnn_type name selects the nonlinearity
        # (an unknown name raises KeyError here).
        rnn_kwargs['nonlinearity'] = {'RNN_TANH': 'tanh',
                                      'RNN_RELU': 'relu'}[config.rnn_type]
        rnn_cls = torch.nn.RNN

    # Input size and hidden size are both embedding_dim.
    self.rnn = rnn_cls(config.embedding_dim,
                       config.embedding_dim,
                       config.rnn_layer_num,
                       **rnn_kwargs)
评论列表
文章目录