model.py 文件源码

python
阅读 20 收藏 0 点赞 0 评论 0

项目:torch_light 作者: ne7ermore 项目源码 文件源码
def __init__(self, args):
        super().__init__()

        for k, v in args.__dict__.items():
            self.__setattr__(k, v)

        self.num_directions = 2 if self.bidirectional else 1
        self.lookup_table = nn.Embedding(self.vocab_size, self.embed_dim)
        self.lstm = nn.LSTM(self.embed_dim,
                    self.hidden_size,
                    self.lstm_layers,
                    batch_first=True,
                    dropout=self.dropout,
                    bidirectional=self.bidirectional)
        self.lr = nn.Linear(self.hidden_size*self.num_directions,
                        self.vocab_size)

        self._init_weights()
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号