def __call__(self, inputs, state, scope=None):
    """Run one step of the layer-normalized two-dimensional LSTM cell.

    @param inputs: input tensor of shape (batch, n)
    @param state: 4-tuple ``(c1, c2, h1, h2)`` carrying the cell states and
        hidden states of the two predecessor cells
    @param scope: optional variable-scope name; defaults to the class name
    @return: ``(new_h, new_state)`` where ``new_state`` is an
        ``LSTMStateTuple(new_c, new_h)``
    """
    with tf.variable_scope(scope or type(self).__name__):
        c1, c2, h1, h2 = state
        # Bias is turned off in the linear map: layer normalization adds
        # its own shift term, so a separate bias would be redundant.
        preact = _linear([inputs, h1, h2], 5 * self._num_units, False)
        chunks = tf.split(value=preact, num_or_size_splits=5, axis=1)
        # Layer-normalize every gate pre-activation, each under its own scope.
        i, j, f1, f2, o = [
            ln(g, scope=s)
            for g, s in zip(chunks, ('i/', 'j/', 'f1/', 'f2/', 'o/'))
        ]
        # Mix both predecessor cell states through their forget gates, then
        # add the candidate update gated by the input gate.
        new_c = (c1 * tf.nn.sigmoid(f1 + self._forget_bias) +
                 c2 * tf.nn.sigmoid(f2 + self._forget_bias) + tf.nn.sigmoid(i) *
                 self._activation(j))
        # Layer-normalize the fresh cell state before emitting the hidden state.
        new_h = self._activation(ln(new_c, scope='new_h/')) * tf.nn.sigmoid(o)
        new_state = LSTMStateTuple(new_c, new_h)
        return new_h, new_state
# Source file: md_lstm.py (Python).
# Page metadata from the original hosting site (removed from code):
# views 22 · bookmarks 0 · likes 0 · comments 0.