def forward(self, xt, fc_feats, att_feats, p_att_feats, state):
    """Run one step of an attention-augmented, LSTM-style recurrence.

    The previous hidden state drives ``self.attention``; the attended
    result is projected by ``self.a2c`` and added to the candidate
    (input-transform) part of the pre-activations before the cell update.

    Args:
        xt: Input for the current step, fed to ``self.i2h``.
        fc_feats: Not read by this method; kept in the signature for
            callers that pass it.
        att_feats: Forwarded unchanged to ``self.attention``.
        p_att_feats: Forwarded unchanged to ``self.attention``.
        state: Pair ``(h, c)``; only the last entry along the first axis
            of each (``h[-1]``, ``c[-1]``) is used.
            NOTE(review): presumably each is
            (num_layers, batch, rnn_size) -- confirm against the caller.

    Returns:
        A pair ``(output, state)`` where ``output`` is the new hidden
        state passed through ``self.dropout`` and ``state`` is
        ``(next_h.unsqueeze(0), next_c.unsqueeze(0))``. The returned
        hidden state is the one taken *before* ``self.dropout``.
    """
    prev_h = state[0][-1]
    prev_c = state[1][-1]

    att_res = self.attention(prev_h, att_feats, p_att_feats)

    # Columns along dim 1 are read as 5 * rnn_size wide:
    # [in | forget | out | candidate half A | candidate half B].
    all_input_sums = self.i2h(xt) + self.h2h(prev_h)

    # torch.sigmoid / torch.tanh replace the deprecated F.sigmoid / F.tanh.
    sigmoid_chunk = torch.sigmoid(
        all_input_sums.narrow(1, 0, 3 * self.rnn_size))
    in_gate = sigmoid_chunk.narrow(1, 0, self.rnn_size)
    forget_gate = sigmoid_chunk.narrow(1, self.rnn_size, self.rnn_size)
    out_gate = sigmoid_chunk.narrow(1, self.rnn_size * 2, self.rnn_size)

    # The attention projection is added to both candidate halves, so
    # self.a2c must yield 2 * rnn_size columns (or broadcast to that).
    in_transform = (
        all_input_sums.narrow(1, 3 * self.rnn_size, 2 * self.rnn_size)
        + self.a2c(att_res)
    )
    # Element-wise maximum of the two halves takes the place of the
    # usual tanh on the candidate values.
    in_transform = torch.max(
        in_transform.narrow(1, 0, self.rnn_size),
        in_transform.narrow(1, self.rnn_size, self.rnn_size))

    next_c = forget_gate * prev_c + in_gate * in_transform
    next_h = out_gate * torch.tanh(next_c)

    output = self.dropout(next_h)
    # Re-add the leading axis that the [-1] indexing removed.
    state = (next_h.unsqueeze(0), next_c.unsqueeze(0))
    return output, state
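# Comment list (web-page navigation text captured with the code)
# Table of contents (web-page navigation text captured with the code)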