MatchTensor.py source code


Project: MatchZoo    Author: faneshion
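import numpy as np
from keras import initializers


# build() method of the MatchTensor layer class (excerpt).
def build(self, input_shape):
    # Validate the two input shapes, then create the matching weight M.
    if not isinstance(input_shape, list) or len(input_shape) != 2:
        raise ValueError('A `MatchTensor` layer should be called '
                         'on a list of 2 inputs.')
    shape1 = input_shape[0]  # (batch, steps1, embed_dim1)
    shape2 = input_shape[1]  # (batch, steps2, embed_dim2)
    if shape1[0] != shape2[0]:
        raise ValueError(
            'Dimension incompatibility '
            '%s != %s. ' % (shape1[0], shape2[0]) +
            'Layer shapes: %s, %s' % (shape1, shape2))
    if self.init_diag:
        if shape1[2] != shape2[2]:
            raise ValueError('init_diag requires both inputs to have '
                             'the same embedding dimension.')
        # Small uniform noise with each channel's diagonal set to 1.0,
        # so every channel starts close to an identity matrix.
        M_diag = np.float32(np.random.uniform(
            -0.05, 0.05, [self.channel, shape1[2], shape2[2]]))
        for i in range(self.channel):
            for j in range(shape1[2]):
                M_diag[i][j][j] = 1.0
        # A raw ndarray is not a callable initializer; wrap it in Constant.
        self.M = self.add_weight(name='M',
                                 shape=(self.channel, shape1[2], shape2[2]),
                                 initializer=initializers.Constant(M_diag),
                                 trainable=True)
    else:
        self.M = self.add_weight(name='M',
                                 shape=(self.channel, shape1[2], shape2[2]),
                                 initializer='uniform',
                                 trainable=True)
    # Mark the layer as built (standard Keras convention).
    super(MatchTensor, self).build(input_shape)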