import tensorflow as tf
from tensorflow.contrib.rnn import LSTMStateTuple


def bidirectional_rnn(cell_fw, cell_bw, inputs_embedded, input_lengths,
                      scope=None):
    """Bidirectional RNN with concatenated outputs and states."""
    with tf.variable_scope(scope or "birnn") as scope:
        # Run the forward and backward RNNs over the embedded inputs.
        ((fw_outputs, bw_outputs),
         (fw_state, bw_state)) = (
            tf.nn.bidirectional_dynamic_rnn(cell_fw=cell_fw,
                                            cell_bw=cell_bw,
                                            inputs=inputs_embedded,
                                            sequence_length=input_lengths,
                                            dtype=tf.float32,
                                            swap_memory=True,
                                            scope=scope))

        # Concatenate forward and backward outputs along the feature axis.
        outputs = tf.concat((fw_outputs, bw_outputs), 2)

        def concatenate_state(fw_state, bw_state):
            if isinstance(fw_state, LSTMStateTuple):
                # LSTM cells: concatenate cell and hidden states separately.
                state_c = tf.concat(
                    (fw_state.c, bw_state.c), 1, name='bidirectional_concat_c')
                state_h = tf.concat(
                    (fw_state.h, bw_state.h), 1, name='bidirectional_concat_h')
                state = LSTMStateTuple(c=state_c, h=state_h)
                return state
            elif isinstance(fw_state, tf.Tensor):
                # Plain tensor states (e.g. GRU): concatenate directly.
                state = tf.concat((fw_state, bw_state), 1,
                                  name='bidirectional_concat')
                return state
            elif (isinstance(fw_state, tuple) and
                  isinstance(bw_state, tuple) and
                  len(fw_state) == len(bw_state)):
                # Multi-layer cells: recurse over the per-layer states.
                state = tuple(concatenate_state(fw, bw)
                              for fw, bw in zip(fw_state, bw_state))
                return state
            else:
                raise ValueError(
                    'unknown state type: {}'.format((fw_state, bw_state)))

        state = concatenate_state(fw_state, bw_state)
        return outputs, state
Source file: model_components.py (Python)
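A minimal usage sketch, assuming TensorFlow 1.x; the placeholder shapes, cell sizes, and the "encoder" scope name below are illustrative assumptions, not taken from the original file:

# Hypothetical usage sketch (TensorFlow 1.x); names and sizes are illustrative.
inputs_embedded = tf.placeholder(tf.float32, [None, None, 128])  # [batch, time, embed]
input_lengths = tf.placeholder(tf.int32, [None])                 # true length of each sequence

cell_fw = tf.nn.rnn_cell.LSTMCell(64)
cell_bw = tf.nn.rnn_cell.LSTMCell(64)

# outputs: [batch, time, 128] (64 forward + 64 backward units);
# state: LSTMStateTuple whose c and h each have shape [batch, 128].
outputs, state = bidirectional_rnn(cell_fw, cell_bw,
                                   inputs_embedded, input_lengths,
                                   scope="encoder")

Because the forward and backward states are concatenated, any downstream decoder cell consuming this state needs twice the encoder's per-direction state size.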