def forward(self, x):
    # x is a torchtext batch; x.text = (sequence_length, batch_size)
    text = x.text
    batch_size = text.size()[1]
    x = self.embed(text)  # (sequence_length, batch_size, embedding_dim)
if self.config.relation_prediction_mode.upper() == "LSTM":
# h0 / c0 = (layer*direction, batch_size, hidden_dim)
if self.config.cuda:
h0 = Variable(torch.zeros(self.config.num_layer * 2, batch_size,
self.config.hidden_size).cuda())
c0 = Variable(torch.zeros(self.config.num_layer * 2, batch_size,
self.config.hidden_size).cuda())
else:
h0 = Variable(torch.zeros(self.config.num_layer * 2, batch_size,
self.config.hidden_size))
c0 = Variable(torch.zeros(self.config.num_layer * 2, batch_size,
self.config.hidden_size))
# output = (sentence length, batch_size, hidden_size * num_direction)
# ht = (layer*direction, batch, hidden_dim)
# ct = (layer*direction, batch, hidden_dim)
outputs, (ht, ct) = self.lstm(x, (h0, c0))
tags = self.hidden2tag(ht[-2:].transpose(0, 1).contiguous().view(batch_size, -1))
scores = F.log_softmax(tags)
return scores
elif self.config.relation_prediction_mode.upper() == "GRU":
if self.config.cuda:
h0 = Variable(torch.zeros(self.config.num_layer * 2, batch_size,
self.config.hidden_size).cuda())
else:
h0 = Variable(torch.zeros(self.config.num_layer * 2, batch_size,
self.config.hidden_size))
outputs, ht = self.gru(x, h0)
tags = self.hidden2tag(ht[-2:].transpose(0, 1).contiguous().view(batch_size, -1))
scores = F.log_softmax(tags)
return scores
elif self.config.relation_prediction_mode.upper() == "CNN":
x = x.transpose(0, 1).contiguous().unsqueeze(1) # (batch, channel_input, sent_len, embed_dim)
x = [F.relu(self.conv1(x)).squeeze(3), F.relu(self.conv2(x)).squeeze(3), F.relu(self.conv3(x)).squeeze(3)]
# (batch, channel_output, ~=sent_len) * Ks
x = [F.max_pool1d(i, i.size(2)).squeeze(2) for i in x] # max-over-time pooling
# (batch, channel_output) * Ks
x = torch.cat(x, 1) # (batch, channel_output * Ks)
x = self.dropout(x)
logit = self.fc1(x) # (batch, target_size)
scores = F.log_softmax(logit)
return scores
else:
print("Unknown Mode")
exit(1)
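
For reference, forward() above assumes the module already owns self.embed, self.lstm, self.gru, self.conv1/2/3, self.dropout, self.fc1, self.hidden2tag and a config object, plus the usual imports (torch, torch.nn.functional as F, and Variable from torch.autograd). The constructor below is only a minimal sketch of how those attributes could be wired up so that the shapes in forward() line up; the class name RelationClassifier and the config fields words_num, words_dim, output_channel, dropout and label_size are illustrative assumptions, not the original code.

import torch
import torch.nn as nn
import torch.nn.functional as F        # F.relu / F.log_softmax in forward()
from torch.autograd import Variable    # h0/c0 in forward() (older PyTorch style)


class RelationClassifier(nn.Module):
    """Hypothetical container for the forward() method shown above."""

    def __init__(self, config):
        super(RelationClassifier, self).__init__()
        self.config = config
        self.embed = nn.Embedding(config.words_num, config.words_dim)
        # bidirectional encoders: final hidden states come back as
        # (num_layer * 2, batch, hidden_size), matching h0/c0 in forward()
        self.lstm = nn.LSTM(config.words_dim, config.hidden_size,
                            num_layers=config.num_layer, bidirectional=True)
        self.gru = nn.GRU(config.words_dim, config.hidden_size,
                          num_layers=config.num_layer, bidirectional=True)
        # the forward and backward final states are concatenated before this projection
        self.hidden2tag = nn.Linear(config.hidden_size * 2, config.label_size)
        # three filter heights (Ks = 3), each producing output_channel feature maps
        self.conv1 = nn.Conv2d(1, config.output_channel, (2, config.words_dim))
        self.conv2 = nn.Conv2d(1, config.output_channel, (3, config.words_dim))
        self.conv3 = nn.Conv2d(1, config.output_channel, (4, config.words_dim))
        self.dropout = nn.Dropout(config.dropout)
        self.fc1 = nn.Linear(config.output_channel * 3, config.label_size)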