def pos_prediction(self, focus=0.5):
    """Predict a tag distribution for every step using a backward recurrence.

    ``self.outputs`` is unpacked as ``(outputs, size, batch_size)``.  Each
    step's features are projected to ``num_class`` logits, one per entry of
    ``POS_tagging['P']``.  Because the sequence is reversed before the scan,
    the softmax prediction of the *following* step (in original order) is fed
    back through a learned ``num_class x num_class`` transition matrix.  That
    feedback is scaled by ``focus`` and added to the logits.

    Args:
        focus: Weight applied to the transition term.  Defaults to 0.5, the
            value previously hard-coded, so existing callers are unaffected.

    Returns:
        Tensor of per-step softmax probabilities.  ``tf.scan`` stacks the
        per-step ``[batch_size, num_class]`` predictions along axis 0, and the
        final transpose swaps the first two axes.  For ``outputs`` shaped
        ``(steps, batch, size)`` the result is ``(batch, steps, num_class)``.
        NOTE(review): the time-major input layout is inferred from the scan
        over axis 0 and the final transpose; confirm against the caller.
    """
    outputs, size, batch_size = self.outputs
    num_class = len(POS_tagging['P'])
    output_w = weight_variable([size, num_class])
    output_b = bias_variable([num_class])
    tag_trans = weight_variable([num_class, num_class])

    # Boolean-mask form of tf.reverse (pre-1.0 TensorFlow API): flip axis 0 so
    # the scan below runs from the last step to the first.
    # NOTE(review): on TensorFlow >= 1.0 this must be tf.reverse(x, [0]).
    outputs = tf.reverse(outputs, [True, False, False])

    def transition(previous_pred, x):
        """One recurrence step: emission logits plus weighted transition term."""
        res = tf.matmul(x, output_w) + output_b
        # Shift each row so its smallest entry becomes 0.  Broadcasting the
        # [batch, 1] row minimum replaces the former explicit tf.tile.
        row_min = tf.expand_dims(
            tf.reduce_min(previous_pred, reduction_indices=1), 1)
        previous_pred = previous_pred - row_min
        res += tf.matmul(previous_pred, tag_trans) * focus
        return tf.nn.softmax(res)

    # Recurrent pass; the first step sees an all-zero previous prediction.
    pred = tf.scan(transition, outputs,
                   initializer=tf.zeros([batch_size, num_class]),
                   parallel_iterations=100)
    # Restore the original step order, then swap the first two axes.
    pred = tf.reverse(pred, [True, False, False])
    return tf.transpose(pred, [1, 0, 2])
评论列表
文章目录