def forward(self, input_vb):
    """Run one step: controller, then memory access, then output projection.

    NOTE: the operation order must be the following: control,
    access {write, read}, output.

    Args:
        input_vb: Input for the current step. It is passed unchanged to
            ``self.controller.forward`` together with the read vector
            stored from the previous call (expected shape is not visible
            from this method -- confirm against the controller).

    Returns:
        The sigmoid of the clipped projection, viewed as
        ``(1, self.batch_size, self.output_dim)``.

    Side effects:
        ``self.read_vec_vb`` is overwritten with the accessor's result, so
        the next call feeds the new read vector to the controller.
    """
    # 1. First feed {input, read_vec_{t-1}} to the controller.
    hidden_vb = self.controller.forward(input_vb, self.read_vec_vb)

    # 2. Then write to memory_{t-1} to get memory_{t}, and read from
    #    memory_{t} to get read_vec_{t} (both done inside the accessor).
    self.read_vec_vb = self.accessor.forward(hidden_vb)

    # 3. Finally concatenate the controller output with the current
    #    read_vec_{t} along dim 1 and project it to the output size.
    hidden_flat = hidden_vb.view(-1, self.hidden_dim)
    read_flat = self.read_vec_vb.view(-1, self.read_vec_dim)
    output_vb = self.hid_to_out(torch.cat((hidden_flat, read_flat), 1))

    # Clip the pre-activation values to [-clip_value, clip_value] before
    # squashing them.
    clipped_vb = torch.clamp(output_vb,
                             min=-self.clip_value,
                             max=self.clip_value)

    # torch.sigmoid is the supported spelling; F.sigmoid is deprecated and
    # emits a warning on every call.
    return torch.sigmoid(clipped_vb).view(1, self.batch_size, self.output_dim)
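# Comment list  (web-page navigation residue, not part of the code)
# Table of contents  (web-page navigation residue, not part of the code)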