def __init__(self, args, attr_size, node_size):
    """Build the embedding -> LSTM -> linear network and its Adam optimizer.

    Args:
        args: Configuration object. This constructor reads its
            ``batch_size``, ``seq_length``, ``embedding_dim``, ``layer_num``,
            ``dropout_prob`` and ``lr`` attributes.
        attr_size: Number of rows in the attribute embedding table
            (the attribute vocabulary size).
        node_size: Output size of the final linear layer ``fc``
            (presumably the number of node types to predict -- confirm).
    """
    super(TreeLM, self).__init__()

    # Hyper-parameters taken from ``args`` and stored on the instance.
    self.batch_size = args.batch_size
    self.seq_length = args.seq_length
    self.attr_size = attr_size
    self.node_size = node_size
    self.embedding_dim = args.embedding_dim
    self.layer_num = args.layer_num
    self.dropout_prob = args.dropout_prob
    self.lr = args.lr

    # Maps attribute indices to dense vectors of size embedding_dim.
    self.attr_embedding = nn.Embedding(self.attr_size, self.embedding_dim)
    self.dropout = nn.Dropout(self.dropout_prob)

    # nn.LSTM applies its ``dropout`` only between stacked layers, so with a
    # single layer it is a no-op and PyTorch emits a UserWarning. Pass it only
    # when there is more than one layer; the computation is unchanged.
    lstm_dropout = self.dropout_prob if self.layer_num > 1 else 0.0
    self.lstm = nn.LSTM(
        input_size=self.embedding_dim,
        hidden_size=self.embedding_dim,
        num_layers=self.layer_num,
        dropout=lstm_dropout,
    )

    # Linear layer from embedding_dim (the LSTM hidden size) to node_size;
    # presumably applied to the LSTM outputs in forward -- confirm.
    self.fc = nn.Linear(self.embedding_dim, self.node_size)

    # Must stay last: Adam captures the parameters registered *at this point*,
    # so any module added after this line (e.g. by a subclass) is not trained.
    self.optimizer = optim.Adam(self.parameters(), lr=self.lr)
评论列表
文章目录