model_seg+pos.py source code


Project: tensorflow-CWS-LSTM  Author: elvinpoon
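```python
import tensorflow as tf  # TensorFlow 1.x graph-mode API

# weight_variable, bias_variable and POS_tagging are defined elsewhere in the project.

def pos_prediction(self):
    # self.outputs: LSTM outputs (time-major [time, batch, size]), hidden size, batch size.
    outputs, size, batch_size = self.outputs
    num_class = len(POS_tagging['P'])

    # Per-step emission layer and a learned tag-to-tag transition matrix.
    output_w = weight_variable([size, num_class])
    output_b = bias_variable([num_class])
    # outputs = tf.transpose(outputs, [1, 0, 2])
    tag_trans = weight_variable([num_class, num_class])

    # Reverse the time axis so tf.scan walks the sentence from right to left.
    # (The original used the pre-1.0 boolean-mask form: [True, False, False].)
    outputs = tf.reverse(outputs, axis=[0])

    def transition(previous_pred, x):
        # Emission scores for the current step.
        res = tf.matmul(x, output_w) + output_b
        # Shift each row of the previous prediction so its minimum becomes 0.
        deviation = tf.tile(tf.expand_dims(tf.reduce_min(previous_pred, axis=1), 1),
                            [1, num_class])
        previous_pred -= deviation
        # Add the transition contribution, down-weighted by a fixed factor.
        focus = 0.5
        res += tf.matmul(previous_pred, tag_trans) * focus
        prediction = tf.nn.softmax(res)
        return prediction

    # Recurrent network over the (reversed) time steps.
    pred = tf.scan(transition, outputs,
                   initializer=tf.zeros([batch_size, num_class]),
                   parallel_iterations=100)
    # Restore the original time order, then go batch-major: [batch, time, num_class].
    pred = tf.reverse(pred, axis=[0])
    pred = tf.transpose(pred, [1, 0, 2])
    return pred
```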
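As a reading aid, the recurrence that `pos_prediction` builds can be traced with a minimal pure-Python sketch. It assumes `outputs` is time-major `[time][batch][size]` (implied by scanning over axis 0 and transposing `[1, 0, 2]` at the end) and that `weight_variable`/`bias_variable` simply create trainable tensors. The helper names `matmul`, `softmax_rows`, `transition_step` and `backward_tag_scan` are hypothetical, chosen for illustration; this is not code from the tensorflow-CWS-LSTM project.

```python
import math
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain-Python matrix product: rows of a times columns of b."""
    return [[sum(a_ik * b[k][j] for k, a_ik in enumerate(row)) for j in range(len(b[0]))]
            for row in a]


def softmax_rows(m: Matrix) -> Matrix:
    """Row-wise softmax (max subtracted first; mathematically the same result)."""
    out = []
    for row in m:
        peak = max(row)
        exps = [math.exp(v - peak) for v in row]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out


def transition_step(prev_pred: Matrix, x: Matrix, output_w: Matrix,
                    output_b: List[float], tag_trans: Matrix, focus: float = 0.5) -> Matrix:
    """One scan step: softmax(x W + b + focus * shift(prev_pred) @ tag_trans)."""
    # Emission scores for this time step, shape [batch][num_class].
    res = [[v + b for v, b in zip(row, output_b)] for row in matmul(x, output_w)]
    # Shift each row of the previous prediction so its minimum is 0.
    shifted = [[p - min(row) for p in row] for row in prev_pred]
    trans = matmul(shifted, tag_trans)
    res = [[r + focus * t for r, t in zip(r_row, t_row)] for r_row, t_row in zip(res, trans)]
    return softmax_rows(res)


def backward_tag_scan(outputs, output_w, output_b, tag_trans, focus=0.5):
    """outputs: [time][batch][size] -> predictions [batch][time][num_class]."""
    batch_size = len(outputs[0])
    num_class = len(output_b)
    pred = [[0.0] * num_class for _ in range(batch_size)]  # scan initializer: zeros
    per_step = []
    for x in reversed(outputs):  # tf.reverse on the time axis, then tf.scan
        pred = transition_step(pred, x, output_w, output_b, tag_trans, focus)
        per_step.append(pred)
    per_step.reverse()  # undo the time reversal
    # Transpose [time][batch][class] -> [batch][time][class].
    return [[per_step[t][b] for t in range(len(per_step))] for b in range(batch_size)]
```

Because the scan starts from a zero matrix, the prediction for the last token depends only on its own emission scores; every earlier token mixes its emission with the (min-shifted) prediction of the token to its right through `tag_trans`, scaled by `focus = 0.5`. So despite the name `previous_pred`, information flows from the end of the sentence toward the start.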