merge.py source code

python

Project: TensorGraph  Author: hycis
import tensorflow as tf

def _train_fprop(self, state_list):
    '''Apply a softmax over the first n units of each row, where n is given
       by seqlen; the remaining (masked) units are set to 0.
       state_list : [state_below, seqlen]
            state_below (2d tf tensor): shape = [batchsize, layer_dim]
            seqlen (1d tf tensor): shape = [batchsize]
            example:
                state_below = 3 x 5 matrix
                seqlen = [2, 1, 4]
    '''
    assert len(state_list) == 2
    state_below, seqlen = state_list
    assert len(seqlen.get_shape()) == 1
    shape = state_below.get_shape()
    assert len(shape) == 2, 'state below dimension {} != 2'.format(len(shape))
    # 1.0 for the first seqlen[i] units of row i, 0.0 for the rest
    mask = tf.cast(tf.sequence_mask(seqlen, shape[-1]), tf.float32)
    # Masked units contribute nothing to the numerator or the row sum
    exp = tf.exp(state_below) * mask
    exp_sum = tf.reduce_sum(exp, axis=1)
    # Rows with seqlen == 0 sum to 0; add 1 there to avoid dividing by zero
    zeros = tf.cast(tf.equal(exp_sum, 0), tf.float32)
    softmax = exp / tf.expand_dims(exp_sum + zeros, -1)
    # Force fully masked rows to all zeros
    nonzeros = tf.cast(tf.not_equal(exp_sum, 0), tf.float32)
    softmax = softmax * tf.expand_dims(nonzeros, -1)
    return softmax
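
The docstring example (a 3 x 5 `state_below` with `seqlen = [2, 1, 4]`) can be checked by hand with a framework-free sketch. The helper below, `masked_softmax`, is a hypothetical name and is not part of TensorGraph. It works on plain Python lists and assumes an all-zero input matrix so the expected values are easy to verify.

```python
import math

def masked_softmax(rows, seqlen):
    """Reference masked softmax on plain lists (illustrative helper, not TensorGraph API)."""
    out = []
    for row, n in zip(rows, seqlen):
        # Only the first n units of each row take part in the softmax.
        exps = [math.exp(x) if j < n else 0.0 for j, x in enumerate(row)]
        total = sum(exps)
        # A row with seqlen 0 sums to 0; emit zeros instead of dividing by zero.
        out.append([e / total if total > 0 else 0.0 for e in exps])
    return out

state_below = [[0.0] * 5 for _ in range(3)]  # assumed input: 3 x 5 matrix of zeros
seqlen = [2, 1, 4]
result = masked_softmax(state_below, seqlen)
for row in result:
    print(row)
# [0.5, 0.5, 0.0, 0.0, 0.0]
# [1.0, 0.0, 0.0, 0.0, 0.0]
# [0.25, 0.25, 0.25, 0.25, 0.0]
```

Each row sums to 1 over its unmasked units, and a row whose `seqlen` is 0 comes back as all zeros, which is the case the `zeros` and `nonzeros` terms guard against in the TensorFlow version.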