def build(self, input_shape):
    """Build the layer and its causal (PixelCNN-style) kernel mask.

    Runs the parent ``build`` first, then creates a mask of ones with shape
    ``self.W_shape``. It zeroes out the kernel entries the current output
    position must not depend on:

    * every index on the first axis after the centre index;
    * from the centre index on, every index on the second axis after the
      centre;
    * at the centre position itself:
        - when ``self.mono`` is true, the centre is zeroed only for mask
          type ``'A'``;
        - otherwise both channel axes are split into ``self.n_channels``
          interleaved groups. Group ``j`` on the last axis keeps group ``i``
          on the third axis only when ``i < j`` (type ``'A'``) or
          ``i <= j`` (any other type).

    The finished mask is stored in ``self.mask`` as ``K.variable(mask)``.

    Args:
        input_shape: Shape of the layer input; forwarded unchanged to the
            parent ``build``.

    Raises:
        ValueError: If the first two kernel dimensions differ. The masking
            assumes a square filter.
    """
    super().build(input_shape)

    # NOTE(review): assumes self.W_shape is (rows, cols, in_channels,
    # out_channels). The first two axes are treated as the spatial kernel
    # and the last two as channel axes; confirm against the parent layer.
    mask = np.ones(self.W_shape)
    # Raise instead of assert so the check is not stripped under `python -O`.
    if mask.shape[0] != mask.shape[1]:
        raise ValueError(
            "masked convolution requires a square kernel, got shape {}".format(
                mask.shape
            )
        )

    filter_size = mask.shape[0]
    # Exact integer forms of floor(filter_size / 2) and ceil(filter_size / 2).
    center = filter_size // 2
    after_center = (filter_size + 1) // 2

    # Hide every row after the centre row...
    mask[after_center:] = 0
    # ...and, from the centre row down, every column after the centre column.
    mask[center:, after_center:] = 0
    # NOTE: for an even filter_size, center == after_center. The row at index
    # filter_size // 2 (centre pixel included) is then masked entirely, so
    # odd filter sizes are presumably intended.

    if self.mono:
        # Single channel group: type 'A' additionally hides the centre pixel.
        if self.mask_type == 'A':
            mask[center, center] = 0
    else:
        n = self.n_channels
        # Channel index k belongs to group k % n (slice i::n). Type 'A' also
        # blocks a group from seeing itself; other types only block i > j.
        block_same_group = self.mask_type == 'A'
        for i in range(n):
            for j in range(n):
                if i > j or (block_same_group and i == j):
                    mask[center, center, i::n, j::n] = 0

    self.mask = K.variable(mask)
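# (Scraped web-page navigation, not code: "Comment list")
# (Scraped web-page navigation, not code: "Article table of contents")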