modules.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目:ParlAI 作者: facebookresearch 项目源码 文件源码
def __init__(self, num_features, padding_idx=0, rnn_class='lstm',
             emb_size=128, hidden_size=128, num_layers=2, dropout=0.1,
             bidir_input=False, share_output=True,
             attn_type='none', attn_length=-1):
    """Build the decoder's embedding, recurrent, attention and output layers.

    Args:
        num_features: vocabulary size; used as the number of embedding rows
            and as the output width of the score projection.
        padding_idx: embedding padding index. Only 0 is supported; any other
            value raises ``RuntimeError``.
        rnn_class: either a recurrent-layer class, called as
            ``rnn_class(emb_size, hidden_size, num_layers, dropout=...,
            batch_first=True)``, or one of the strings ``'rnn'``, ``'gru'``,
            ``'lstm'`` (case-insensitive), which map to ``nn.RNN``,
            ``nn.GRU`` and ``nn.LSTM`` respectively.
        emb_size: token embedding dimension.
        hidden_size: recurrent hidden-state dimension.
        num_layers: number of stacked recurrent layers.
        dropout: dropout probability passed to the recurrent layer; also
            stored on ``self.dropout``.
        bidir_input: forwarded to ``AttentionLayer`` as ``bidirectional``
            (presumably whether the encoder outputs are bidirectional --
            TODO confirm against AttentionLayer).
        share_output: if True, the output projection reuses the embedding
            weight matrix ``self.lt.weight``.
        attn_type: attention variant name forwarded to ``AttentionLayer``.
        attn_length: forwarded unchanged to ``AttentionLayer``.

    Raises:
        RuntimeError: if ``padding_idx`` is not 0.
        ValueError: if ``rnn_class`` is a string that is not a known name.
    """
    super().__init__()

    if padding_idx != 0:
        raise RuntimeError('This module\'s output layer needs to be fixed '
                           'if you want a padding_idx other than zero.')

    # The default ('lstm') is a string, but the value is *called* below, so
    # translate known names into classes first. Without this, constructing
    # the module with its default failed with "'str' object is not
    # callable". Classes passed in directly are used unchanged.
    if isinstance(rnn_class, str):
        rnn_opts = {'rnn': nn.RNN, 'gru': nn.GRU, 'lstm': nn.LSTM}
        try:
            rnn_class = rnn_opts[rnn_class.lower()]
        except KeyError:
            raise ValueError(
                'Unknown rnn_class {!r}; expected one of {} or an RNN '
                'class.'.format(rnn_class, sorted(rnn_opts))) from None

    self.dropout = dropout
    self.layers = num_layers
    self.hsz = hidden_size

    self.lt = nn.Embedding(num_features, emb_size, padding_idx=padding_idx)
    self.rnn = rnn_class(emb_size, hidden_size, num_layers,
                         dropout=dropout, batch_first=True)

    # rnn output to embedding
    self.o2e = nn.Linear(hidden_size, emb_size)
    # embedding to scores, use custom linear to possibly share weights
    # (when share_output is True the embedding matrix doubles as the
    # output projection weight).
    shared_weight = self.lt.weight if share_output else None
    self.e2s = Linear(emb_size, num_features, bias=False,
                      shared_weight=shared_weight)
    # Records whether output weights are tied to the embedding.
    self.shared = shared_weight is not None

    self.attn_type = attn_type
    self.attention = AttentionLayer(attn_type=attn_type,
                                    hidden_size=hidden_size,
                                    emb_size=emb_size,
                                    bidirectional=bidir_input,
                                    attn_length=attn_length)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号