def forward(self, xt, fc_feats, att_feats, p_att_feats, state):
    """Run one step of the attention-LSTM -> attention -> language-LSTM core.

    Args:
        xt: Current-step input. Concatenated along dim 1 with ``fc_feats``
            and the previous hidden state, so it must be at least 2-D
            (presumably ``(batch, dim)`` -- confirm against the caller).
        fc_feats: Features concatenated into the attention-LSTM input.
        att_feats: Forwarded unchanged to ``self.attention``.
        p_att_feats: Forwarded unchanged to ``self.attention`` (presumably a
            pre-projected form of ``att_feats`` -- confirm).
        state: Pair ``(h, c)``. ``h[0]``/``c[0]`` seed the attention LSTM and
            ``h[1]``/``c[1]`` seed the language LSTM; ``h[-1]`` is also fed
            into the attention-LSTM input.

    Returns:
        ``(output, state)`` where ``output`` is the language-LSTM hidden
        state after dropout and ``state`` is the new ``(h, c)`` pair, each
        stacked in ``[attention, language]`` order.
    """
    # Last entry of the previous hidden states; when the state returned by
    # this method is fed back in, that is the language LSTM's hidden state.
    prev_h = state[0][-1]

    # Attention LSTM: consumes previous hidden state, fc features and input.
    att_lstm_input = torch.cat([prev_h, fc_feats, xt], 1)
    h_att, c_att = self.att_lstm(att_lstm_input, (state[0][0], state[1][0]))

    # Attend over the features using the fresh attention-LSTM hidden state.
    att = self.attention(h_att, att_feats, p_att_feats)

    # Language LSTM: consumes the attended features and h_att.
    # NOTE: h_att is passed in without dropout. Applying
    # F.dropout(h_att, self.drop_prob_lm, self.training) here was considered
    # by the original author and left disabled.
    lang_lstm_input = torch.cat([att, h_att], 1)
    h_lang, c_lang = self.lang_lstm(lang_lstm_input, (state[0][1], state[1][1]))

    # Dropout is applied only to the emitted output, not to the carried state.
    output = F.dropout(h_lang, self.drop_prob_lm, self.training)
    state = (torch.stack([h_att, h_lang]), torch.stack([c_att, c_lang]))
    return output, state
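# Comment list (web-page navigation label, not part of the code)
# Table of contents (web-page navigation label, not part of the code)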