def build_multi_dynamic_brnn(args,
                             maxTimeSteps,
                             inputX,
                             cell_fn,
                             seqLengths,
                             time_major=True):
    """Stack ``args.num_layer`` bidirectional dynamic RNN layers over ``inputX``.

    Each layer runs a forward and a backward cell (both built by ``cell_fn``)
    and merges their outputs by element-wise SUM (not concatenation), so every
    layer's output keeps width ``args.num_hidden``. Dropout is applied after
    each layer when ``args.mode == 'train'``.

    Args:
        args: config object; fields read here: ``num_layer``, ``num_hidden``,
            ``activation``, ``keep_prob``, ``mode``, ``batch_size``.
        maxTimeSteps: number of time steps; the final output is split into
            this many per-step tensors.
        inputX: input tensor — assumed [max_time, batch_size, input_size]
            when ``time_major`` is True (TODO confirm against caller).
        cell_fn: RNN cell constructor, called as
            ``cell_fn(num_hidden, activation=...)``.
        seqLengths: per-example sequence lengths, shape [batch_size].
        time_major: whether ``inputX`` is time-major; forwarded to
            ``bidirectional_dynamic_rnn``.

    Returns:
        List of ``maxTimeSteps`` tensors, each [batch_size, num_hidden],
        from the last layer.

    Raises:
        NameError: if ``args.num_layer`` is 0 (the loop never runs, so the
            result is never bound).
    """
    hid_input = inputX
    for i in range(args.num_layer):
        scope = 'DBRNN_' + str(i + 1)
        forward_cell = cell_fn(args.num_hidden, activation=args.activation)
        backward_cell = cell_fn(args.num_hidden, activation=args.activation)
        # outputs is a (forward, backward) pair; the final states are unused.
        # BUG FIX: was hardcoded time_major=True, silently ignoring the
        # caller's time_major argument — now forwarded.
        outputs, _ = bidirectional_dynamic_rnn(forward_cell, backward_cell,
                                               inputs=hid_input,
                                               dtype=tf.float32,
                                               sequence_length=seqLengths,
                                               time_major=time_major,
                                               scope=scope)
        output_fw, output_bw = outputs
        # Concatenate [fw, bw] on the feature axis, then reshape so the two
        # directions become an axis of size 2 and sum over it — i.e. merge
        # directions by addition, keeping feature width num_hidden.
        output_fb = tf.concat([output_fw, output_bw], 2)
        shape = output_fb.get_shape().as_list()
        output_fb = tf.reshape(output_fb, [shape[0], shape[1], 2, shape[2] // 2])
        hidden = tf.reduce_sum(output_fb, 2)
        hidden = dropout(hidden, args.keep_prob, (args.mode == 'train'))
        if i != args.num_layer - 1:
            # Intermediate layer: feed the merged output into the next layer.
            hid_input = hidden
        else:
            # Last layer: flatten to [max_time * batch_size, num_hidden] and
            # split back into one [batch_size, num_hidden] tensor per step.
            outputXrs = tf.reshape(hidden, [-1, args.num_hidden])
            output_list = tf.split(outputXrs, maxTimeSteps, 0)
            fbHrs = [tf.reshape(t, [args.batch_size, args.num_hidden])
                     for t in output_list]
    return fbHrs
dynamic_brnn.py 文件源码
python
阅读 23
收藏 0
点赞 0
评论 0
评论列表
文章目录