lstm_attention.py source code

Language: Python

Project: pytorch-seq2seq · Author: rowanz

import torch.nn as nn
from torchvision import models


class EncoderRNN(nn.Module):
    def __init__(self, input_size, hidden_size, use_embedding=False, use_cnn=False,
                 vocab_size=None, pad_idx=None):
        """
        Bidirectional GRU for encoding sequences
        :param input_size: Size of the feature dimension (or, if use_embedding=True, the embed dim)
        :param hidden_size: Size of the GRU hidden layer. Outputs will be hidden_size*2
        :param use_embedding: True if we need to embed the sequences
        :param vocab_size: Size of vocab (only used if use_embedding=True)
        """

        super(EncoderRNN, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.gru = nn.GRU(input_size, hidden_size, bidirectional=True)

        self.use_embedding = use_embedding
        self.use_cnn = use_cnn
        self.vocab_size = vocab_size
        self.embed = None
        if self.use_embedding:
            assert self.vocab_size is not None
            self.pad = pad_idx
            self.embed = nn.Embedding(self.vocab_size, self.input_size, padding_idx=pad_idx)
        elif self.use_cnn:
            # Pretrained ResNet-50 as a frozen visual feature extractor; only
            # the replacement fc layer below is trained.
            self.embed = models.resnet50(pretrained=True)

            for param in self.embed.parameters():
                param.requires_grad = False
            self.embed.fc = nn.Linear(self.embed.fc.in_features, self.input_size)

            # Init weights (should be moved.)
            self.embed.fc.weight.data.normal_(0.0, 0.02)
            self.embed.fc.bias.data.fill_(0)
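
For reference, a minimal usage sketch (hyperparameter values such as input_size=64 and vocab_size=1000 are hypothetical, and since the class's forward() is not part of this excerpt, the embed and gru submodules are called directly):

import torch

encoder = EncoderRNN(input_size=64, hidden_size=128,
                     use_embedding=True, vocab_size=1000, pad_idx=0)

tokens = torch.randint(1, 1000, (12, 4))   # (seq_len=12, batch=4); nn.GRU is seq-first by default
embedded = encoder.embed(tokens)           # -> (12, 4, 64)
outputs, h_n = encoder.gru(embedded)       # outputs: (12, 4, 256), i.e. hidden_size*2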