def call(self, x, mask=None):
    """Run the layer's forward computation on ``x``.

    For each output position, a window of ``self.filter_length`` entries is
    cut from the second axis of ``x``; consecutive windows start
    ``self.subsample_length`` entries apart. The windows are stacked,
    multiplied against ``self.W`` with ``K.batch_dot``, optionally offset by
    ``self.b``, and finally passed through ``self.activation``.

    Args:
        x: Input tensor, indexed here along three axes.
            NOTE(review): presumably (batch, steps, features) -- confirm
            against the caller.
        mask: Accepted but not read by this method.

    Returns:
        The result of ``self.activation`` applied to the batched product
        (with its first two axes swapped) plus the optional bias.
    """
    step = self.subsample_length
    n_positions, feature_dim, nb_filter = self.W_shape

    # One reshaped window per output position; the leading axis of size 1
    # lets the windows be concatenated along axis 0 below.
    windows = [
        K.reshape(
            x[:, pos * step:pos * step + self.filter_length, :],
            (1, -1, feature_dim),
        )
        for pos in range(n_positions)
    ]
    stacked = K.concatenate(windows, axis=0)

    # The batched product is laid out as
    # (output_length, batch_size, nb_filter); swap the first two axes so the
    # bias, reshaped to (1, output_length, nb_filter), lines up with it.
    output = K.permute_dimensions(K.batch_dot(stacked, self.W), (1, 0, 2))

    if self.bias:
        output += K.reshape(self.b, (1, n_positions, nb_filter))

    return self.activation(output)
评论列表
文章目录