def __init__(self, args):
    """Build the embedding -> LSTM -> linear projection stack from ``args``.

    Every attribute of ``args`` (any object exposing a ``__dict__``, e.g. an
    ``argparse.Namespace``) is first copied onto the module, so the
    hyper-parameters below are read back as ``self.<name>``.

    Attributes read from ``args``:
        vocab_size: rows of the embedding table and width of the output
            projection.
        embed_dim: embedding size, also the LSTM input size.
        hidden_size: LSTM hidden-state size per direction.
        lstm_layers: number of stacked LSTM layers.
        dropout: passed to ``nn.LSTM`` as inter-layer dropout when more than
            one layer is stacked (the attribute itself is kept on ``self``
            unchanged; other methods may also use it).
        bidirectional: whether the LSTM runs in both directions.
    """
    super().__init__()
    # NOTE(review): every field of ``args`` becomes a module attribute, which
    # could shadow nn.Module attributes/methods (e.g. an arg named
    # ``training``) — confirm ``args`` only carries hyper-parameters.
    for name, value in vars(args).items():
        setattr(self, name, value)

    self.num_directions = 2 if self.bidirectional else 1
    self.lookup_table = nn.Embedding(self.vocab_size, self.embed_dim)

    # nn.LSTM applies dropout only between stacked layers. With a single
    # layer a non-zero value has no effect beyond emitting a UserWarning,
    # so pass 0 in that case.
    lstm_dropout = self.dropout if self.lstm_layers > 1 else 0.0
    self.lstm = nn.LSTM(
        self.embed_dim,
        self.hidden_size,
        self.lstm_layers,
        batch_first=True,  # inputs are (batch, seq, embed_dim)
        dropout=lstm_dropout,
        bidirectional=self.bidirectional,
    )

    # Output projection ("lr" presumably = linear layer): maps the
    # concatenated per-direction hidden states to vocab_size outputs.
    self.lr = nn.Linear(
        self.hidden_size * self.num_directions,
        self.vocab_size,
    )
    self._init_weights()
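# Web-page scrape residue (not code): "Comment list" / "Table of contents"
# navigation labels from the page this snippet was copied from.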