def _pos_private(self, encoder_units, config, is_training):
    """Build the POS-tag decoder on top of the encoder outputs.

    An LSTM is run over ``encoder_units`` and every output step is
    projected onto the POS tag vocabulary with a linear layer.  All
    variables are created under the ``"pos_decoder"`` variable scope.

    Args:
        encoder_units: rank-3 tensor of encoder outputs.  Its first two
            axes are swapped before it is fed to the RNN.
            NOTE(review): presumably time-major on entry and batch-major
            afterwards (the layout ``tf.nn.dynamic_rnn`` expects by
            default) -- confirm against the caller.
        config: object providing ``pos_decoder_size`` (LSTM units),
            ``num_pos_tags`` (number of output classes) and
            ``keep_prob`` (dropout keep probability).
        is_training: when truthy and ``config.keep_prob < 1``, dropout is
            applied to the LSTM outputs.

    Returns:
        A ``(logits, decoder_states)`` tuple.  ``logits`` is a 2-D tensor
        whose last dimension is ``config.num_pos_tags``, with all RNN
        output steps flattened into the first dimension.
        ``decoder_states`` is the state returned by ``tf.nn.dynamic_rnn``.
    """
    with tf.variable_scope("pos_decoder"):
        hidden_size = config.pos_decoder_size

        # Inherit the reuse flag from the scope we are currently in so the
        # cell shares weights whenever the surrounding scope asks for it.
        cell = rnn.BasicLSTMCell(
            hidden_size,
            forget_bias=1.0,
            reuse=tf.get_variable_scope().reuse,
        )

        # Dropout is a training-only regulariser and is skipped entirely
        # when keep_prob is 1 or more.
        use_dropout = is_training and config.keep_prob < 1
        if use_dropout:
            cell = rnn.DropoutWrapper(
                cell, output_keep_prob=config.keep_prob)

        # Swap the two leading axes before handing the sequence to the RNN.
        rnn_inputs = tf.transpose(encoder_units, perm=[1, 0, 2])
        decoder_outputs, decoder_states = tf.nn.dynamic_rnn(
            cell,
            rnn_inputs,
            dtype=tf.float32,
            scope="pos_rnn",
        )

        # NOTE(review): dynamic_rnn appears to return a single tensor, which
        # would make this concat redundant (possibly a leftover from a
        # list-returning static RNN) -- confirm before removing.
        joined_outputs = tf.concat(decoder_outputs, 1)
        # Flatten every output step into one row so a single matmul can
        # score all of them.
        flat_outputs = tf.reshape(joined_outputs, [-1, hidden_size])

        num_tags = config.num_pos_tags
        projection_w = tf.get_variable("softmax_w", [hidden_size, num_tags])
        projection_b = tf.get_variable("softmax_b", [num_tags])
        logits = tf.matmul(flat_outputs, projection_w) + projection_b

    return logits, decoder_states
# Comment list
# Table of contents