treelm.py source code

python
Project: Tree-LSTM-LM  Author: vgene
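import torch.nn as nn
import torch.optim as optim


# The base class is not shown in the excerpt; nn.Module is inferred from
# the super() call and the use of self.parameters().
class TreeLM(nn.Module):
    def __init__(self, args, attr_size, node_size):
        super(TreeLM, self).__init__()

        # Data dimensions and vocabulary sizes
        self.batch_size = args.batch_size
        self.seq_length = args.seq_length
        self.attr_size = attr_size
        self.node_size = node_size

        # Model and training hyperparameters
        self.embedding_dim = args.embedding_dim
        self.layer_num = args.layer_num
        self.dropout_prob = args.dropout_prob
        self.lr = args.lr

        # Attribute ids -> dense vectors of size embedding_dim
        self.attr_embedding = nn.Embedding(self.attr_size, self.embedding_dim)
        self.dropout = nn.Dropout(self.dropout_prob)

        # The hidden size equals the embedding size. nn.LSTM applies dropout
        # only between stacked layers, so it has no effect when layer_num == 1.
        self.lstm = nn.LSTM(input_size=self.embedding_dim,
                            hidden_size=self.embedding_dim,
                            num_layers=self.layer_num,
                            dropout=self.dropout_prob)

        # Project LSTM outputs to one score per node type
        self.fc = nn.Linear(self.embedding_dim, self.node_size)
        # Created last so that self.parameters() covers every layer above
        self.optimizer = optim.Adam(self.parameters(), lr=self.lr)
        # self.node_mapping = node_mapping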
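
The excerpt shows only the constructor, so the forward pass is not known. As a rough sketch of how layers configured this way could be chained, the snippet below builds the same kinds of layers with hypothetical sizes and traces the tensor shapes. The sizes, the variable names, and the embedding -> LSTM -> linear order are assumptions for illustration, not the project's actual `forward` method.

```python
import torch
import torch.nn as nn

# Hypothetical sizes, chosen only for illustration.
attr_size, node_size = 50, 30
embedding_dim, layer_num = 16, 2
seq_length, batch_size = 5, 4

# Same layer types and argument pattern as the constructor above.
attr_embedding = nn.Embedding(attr_size, embedding_dim)
lstm = nn.LSTM(input_size=embedding_dim,
               hidden_size=embedding_dim,
               num_layers=layer_num,
               dropout=0.1)
fc = nn.Linear(embedding_dim, node_size)

# nn.LSTM defaults to batch_first=False, so the input is laid out as
# (seq_length, batch_size).
attr_ids = torch.randint(0, attr_size, (seq_length, batch_size))

embedded = attr_embedding(attr_ids)   # (seq_length, batch_size, embedding_dim)
output, (h_n, c_n) = lstm(embedded)   # output keeps the embedded shape
logits = fc(output)                   # (seq_length, batch_size, node_size)

print(tuple(embedded.shape), tuple(h_n.shape), tuple(logits.shape))
```

Because the hidden size equals the embedding size, the LSTM output has the same shape as its input, and the final linear layer changes only the last dimension to `node_size`.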