def __call__(self, inputs, state, scope=None):
    """Attention-based GRU with `num_units` cells (DMN+ episodic memory)."""
    # Module-level imports assumed by this method (TensorFlow 1.x):
    #   from tensorflow.python.ops import variable_scope as vs
    #   from tensorflow.python.ops import array_ops
    #   from tensorflow.python.ops.math_ops import sigmoid
    #   _linear comes from TF's internal RNN-cell utilities; its exact
    #   module path varies across 1.x releases.
    with vs.variable_scope(scope or "attention_gru_cell"):
        with vs.variable_scope("gates"):
            # Reset gate only: the attention score g replaces the
            # standard GRU update gate.
            if inputs.get_shape()[-1] != self._num_units + 1:
                raise ValueError("Input should be passed as word input "
                                 "concatenated with 1D attention on end axis")
            # Split the input into the fact vector and its scalar attention g.
            inputs, g = array_ops.split(inputs,
                                        num_or_size_splits=[self._num_units, 1],
                                        axis=1)
            # Reset gate: r = sigmoid(W_r x + U_r h + b_r).
            r = _linear([inputs, state], self._num_units, True)
            r = sigmoid(r)
        with vs.variable_scope("candidate"):
            r = r * _linear(state, self._num_units, False)  # r * (U h)
        with vs.variable_scope("input"):
            x = _linear(inputs, self._num_units, True)      # W x + b
        h_hat = self._activation(r + x)  # candidate state
        # Attention-weighted update: g gates the new state in place of an
        # update gate, so unattended facts leave the state unchanged.
        new_h = (1 - g) * state + g * h_hat
    return new_h, new_h
Source file: attention_gru_cell.py
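For context, the recurrence implemented above is the attention GRU from DMN+ (Xiong et al., 2016): r_i = sigmoid(W_r x_i + U_r h_{i-1} + b_r), h~_i = tanh(W x_i + b + r_i * (U h_{i-1})), h_i = g_i * h~_i + (1 - g_i) * h_{i-1}, where the scalar attention g_i takes the place of the usual update gate.

A minimal usage sketch follows, assuming the surrounding AttentionGRUCell class follows the TF 1.x RNNCell interface; num_units, the batch size, and the placeholder names are illustrative choices, not part of the original file:

import tensorflow as tf  # TensorFlow 1.x

num_units, batch = 80, 32                               # hypothetical sizes
facts = tf.placeholder(tf.float32, [batch, num_units])  # fact vector x_i
g = tf.placeholder(tf.float32, [batch, 1])              # attention score g_i
state = tf.placeholder(tf.float32, [batch, num_units])  # previous state h_{i-1}

cell = AttentionGRUCell(num_units)      # class defined in this file
# The cell expects the attention concatenated on the last axis.
inputs = tf.concat([facts, g], axis=1)  # shape [batch, num_units + 1]
new_h, _ = cell(inputs, state)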