def __call__(self, inputs, state, scope=None):
"""Long short-term memory cell with attention (LSTMA)."""
with vs.variable_scope(scope or type(self).__name__):
if self._state_is_tuple:
state, attns, attn_states = state
else:
states = state
state = array_ops.slice(states, [0, 0], [-1, self._cell.state_size])
attns = array_ops.slice(
states, [0, self._cell.state_size], [-1, self._attn_size])
attn_states = array_ops.slice(
states, [0, self._cell.state_size + self._attn_size],
[-1, self._attn_size * self._attn_length])
attn_states = array_ops.reshape(attn_states,
[-1, self._attn_length, self._attn_size])
input_size = self._input_size
if input_size is None:
input_size = inputs.get_shape().as_list()[1]
inputs = _linear([inputs, attns], input_size, True)
lstm_output, new_state = self._cell(inputs, state)
if self._state_is_tuple:
new_state_cat = array_ops.concat(1, nest.flatten(new_state))
else:
new_state_cat = new_state
new_attns, new_attn_states = self._attention(new_state_cat, attn_states)
with vs.variable_scope("AttnOutputProjection"):
output = _linear([lstm_output, new_attns], self._attn_size, True)
new_attn_states = array_ops.concat(1, [new_attn_states,
array_ops.expand_dims(output, 1)])
new_attn_states = array_ops.reshape(
new_attn_states, [-1, self._attn_length * self._attn_size])
new_state = (new_state, new_attns, new_attn_states)
if not self._state_is_tuple:
new_state = array_ops.concat(1, list(new_state))
return output, new_state
评论列表
文章目录