def forward(self, embed, state_word):
    # embed token indices: (num_words, batch_size) -> (num_words, batch_size, embed_dim)
    embedded = self.lookup(embed)
    # run the word-level GRU over the embedded sequence
    output_word, state_word = self.word_gru(embedded, state_word)
    # u_it = tanh(W_w * h_it + b_w): hidden representation of each word
    word_squish = batch_matmul_bias(output_word, self.weight_W_word, self.bias_word, nonlinearity='tanh')
    # unnormalized attention score u_it^T * u_w: one scalar per word
    word_attn = batch_matmul(word_squish, self.weight_proj_word)
    # normalize scores across the words of each sentence to get weights alpha_it
    word_attn_norm = self.softmax_word(word_attn.transpose(1, 0))
    # sentence vector s_i = sum_t alpha_it * h_it (attention-weighted sum of GRU outputs)
    word_attn_vectors = attention_mul(output_word, word_attn_norm.transpose(1, 0))
    return word_attn_vectors, state_word, word_attn_norm
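
The helpers batch_matmul_bias, batch_matmul, and attention_mul are defined elsewhere in the repository. A minimal sketch of what they might compute, assuming GRU outputs of shape (seq_len, batch, hidden) and relying on broadcasting; the shapes and signatures here are assumptions, not the repository's exact definitions:

import torch

def batch_matmul_bias(seq, weight, bias, nonlinearity=''):
    # affine map applied at every timestep:
    # (T, B, H) @ (H, H') + (H',) -> (T, B, H')
    s = seq @ weight + bias
    return torch.tanh(s) if nonlinearity == 'tanh' else s

def batch_matmul(seq, weight):
    # project each timestep onto the word context vector:
    # (T, B, H') @ (H', 1) -> (T, B)
    return (seq @ weight).squeeze(-1)

def attention_mul(rnn_outputs, att_weights):
    # attention-weighted sum of GRU outputs over time:
    # (T, B, H) * (T, B, 1), summed over T -> (1, B, H)
    return (rnn_outputs * att_weights.unsqueeze(-1)).sum(dim=0, keepdim=True)

This is the standard word-attention mechanism of Yang et al. (2016): each hidden state is passed through a one-layer MLP, scored against a learned context vector u_w, and the softmax-normalized scores weight the sum that forms the sentence vector.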