relation_prediction.py — file source code

python

Project: BuboQA    Author: castorini
def forward(self, x):
        # x is a batch object; x.text = (sequence length, batch_size) word indices
        text = x.text
        batch_size = text.size()[1]
        x = self.embed(text)  # (sequence length, batch_size, embedding dim)
        if self.config.relation_prediction_mode.upper() == "LSTM":
            # h0 / c0 = (layer*direction, batch_size, hidden_dim)
            if self.config.cuda:
                h0 = Variable(torch.zeros(self.config.num_layer * 2, batch_size,
                                          self.config.hidden_size).cuda())
                c0 = Variable(torch.zeros(self.config.num_layer * 2, batch_size,
                                          self.config.hidden_size).cuda())
            else:
                h0 = Variable(torch.zeros(self.config.num_layer * 2, batch_size,
                                          self.config.hidden_size))
                c0 = Variable(torch.zeros(self.config.num_layer * 2, batch_size,
                                          self.config.hidden_size))
            # output = (sentence length, batch_size, hidden_size * num_direction)
            # ht = (layer*direction, batch, hidden_dim)
            # ct = (layer*direction, batch, hidden_dim)
            outputs, (ht, ct) = self.lstm(x, (h0, c0))
            tags = self.hidden2tag(ht[-2:].transpose(0, 1).contiguous().view(batch_size, -1))
            scores = F.log_softmax(tags, dim=1)
            return scores
        elif self.config.relation_prediction_mode.upper() == "GRU":
            if self.config.cuda:
                h0 = Variable(torch.zeros(self.config.num_layer * 2, batch_size,
                                          self.config.hidden_size).cuda())
            else:
                h0 = Variable(torch.zeros(self.config.num_layer * 2, batch_size,
                                          self.config.hidden_size))
            outputs, ht = self.gru(x, h0)

            tags = self.hidden2tag(ht[-2:].transpose(0, 1).contiguous().view(batch_size, -1))
            scores = F.log_softmax(tags, dim=1)
            return scores
        elif self.config.relation_prediction_mode.upper() == "CNN":
            x = x.transpose(0, 1).contiguous().unsqueeze(1)  # (batch, channel_input, sent_len, embed_dim)
            x = [F.relu(self.conv1(x)).squeeze(3), F.relu(self.conv2(x)).squeeze(3), F.relu(self.conv3(x)).squeeze(3)]
            # (batch, channel_output, ~=sent_len) * Ks
            x = [F.max_pool1d(i, i.size(2)).squeeze(2) for i in x]  # max-over-time pooling
            # (batch, channel_output) * Ks
            x = torch.cat(x, 1)  # (batch, channel_output * Ks)
            x = self.dropout(x)
            logit = self.fc1(x)  # (batch, target_size)
            scores = F.log_softmax(logit, dim=1)
            return scores
        else:
            print("Unknown Mode")
            exit(1)
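
For context, the forward() above only covers the inference path; the layers it references (self.embed, self.lstm, self.gru, self.conv1–conv3, self.hidden2tag, self.dropout, self.fc1) must be created in the model's constructor, which this page does not show. Below is a minimal constructor sketch consistent with the tensor shapes used in forward(). It is an illustration, not BuboQA's actual __init__: the class name RelationPrediction, the config fields words_num, words_dim, output_channel, target_size, dropout, and the convolution kernel heights 2/3/4 are all assumptions; only hidden_size, num_layer, relation_prediction_mode, and cuda appear in the snippet itself.

import torch.nn as nn


class RelationPrediction(nn.Module):
    # Hypothetical constructor matching the layer names used in forward().
    def __init__(self, config):
        super(RelationPrediction, self).__init__()
        self.config = config
        # word embedding lookup: indices -> (seq_len, batch, words_dim)
        self.embed = nn.Embedding(config.words_num, config.words_dim)
        # bidirectional recurrent encoders; forward() selects one by mode
        self.lstm = nn.LSTM(config.words_dim, config.hidden_size,
                            num_layers=config.num_layer, bidirectional=True)
        self.gru = nn.GRU(config.words_dim, config.hidden_size,
                          num_layers=config.num_layer, bidirectional=True)
        # consumes the concatenated last forward/backward hidden states,
        # i.e. ht[-2:] reshaped to (batch, hidden_size * 2)
        self.hidden2tag = nn.Linear(config.hidden_size * 2, config.target_size)
        # three parallel convolutions over the embedded sentence
        # (kernel heights 2/3/4 are an assumption)
        self.conv1 = nn.Conv2d(1, config.output_channel, (2, config.words_dim))
        self.conv2 = nn.Conv2d(1, config.output_channel, (3, config.words_dim))
        self.conv3 = nn.Conv2d(1, config.output_channel, (4, config.words_dim))
        self.dropout = nn.Dropout(config.dropout)
        # concatenated max-pooled feature maps -> relation scores
        self.fc1 = nn.Linear(config.output_channel * 3, config.target_size)

The Variable wrappers in forward() indicate the snippet targets a pre-0.4 PyTorch API; on newer versions plain tensors can be passed directly.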