def seg_prediction(self):
    """Build the recurrent per-step class prediction over ``self.outputs``.

    ``self.outputs`` is unpacked into ``(outputs, size, batch_size)``.
    Each step projects the current slice of ``outputs`` to ``num_class``
    scores, adds a learned transformation of the previous step's
    prediction, and applies a softmax. The very first step sees an
    all-zero previous prediction.

    Returns:
        The stacked per-step softmax outputs with their first two axes
        swapped.
    """
    # NOTE(review): presumably `outputs` is laid out as
    # [steps, batch_size, size], since tf.scan walks the leading axis and
    # the carried state is [batch_size, num_class] -- confirm with caller.
    outputs, size, batch_size = self.outputs
    num_class = self.config.num_class

    # Projection from features to class scores.
    output_w = weight_variable([size, num_class])
    output_b = bias_variable([num_class])
    # Learned mapping from the previous step's prediction to class scores.
    tag_trans = weight_variable([num_class, num_class])

    # Scale applied to the previous-step contribution; 1.0 leaves it as is.
    focus = 1.

    def transition(p, x):
        """Return the softmax prediction for input ``x`` given previous ``p``."""
        emission = tf.matmul(x, output_w) + output_b
        # An earlier, disabled variant subtracted the per-row minimum from
        # the previous prediction before it was used here.
        carried = tf.matmul(p, tag_trans) * focus
        return tf.nn.softmax(emission + carried)

    # Recurrent pass: every step consumes the prediction of the one before.
    initial_pred = tf.zeros([batch_size, num_class])
    stacked = tf.scan(
        transition,
        outputs,
        initializer=initial_pred,
        parallel_iterations=100,
    )
    return tf.transpose(stacked, perm=[1, 0, 2])
评论列表
文章目录