def conv4d(x, weights, bias, output):
    """Batched, multi-filter 2-D cross-correlation (a conv-layer forward pass).

    The name refers to the rank of the arrays: every array argument is 4-D
    and each filter slides over the last two axes. For every image ``n`` and
    filter ``f`` this computes, in 'valid' mode (no padding, no kernel flip)::

        output[n, f] = correlate(x[n], weights[f], mode='valid')[0] + bias[f]

    Because ``weights[f]`` spans all input channels, the 'valid' correlation
    reduces the channel axis to length 1; that singleton axis is dropped.

    Args:
        x: array of shape (batch, channels, in_h, in_w).
        weights: array of shape (num_filters, channels, k_h, k_w).
        bias: sequence or array indexed by filter; each ``bias[f]`` must hold
            exactly one value (a scalar or a size-1 array).
        output: pre-allocated array of shape
            (batch, num_filters, in_h - k_h + 1, in_w - k_w + 1);
            it is filled in place.

    Returns:
        None; results are written into ``output``.

    Raises:
        ValueError: if an array is not 4-D, the batch / filter / channel
            counts disagree, the filter is larger than the input, ``output``
            does not have the exact 'valid' spatial shape, or a ``bias[f]``
            does not hold exactly one value.
        IndexError: if ``bias`` has fewer entries than there are filters.
    """
    import numpy as np  # local import keeps this block self-contained

    # Explicit raises rather than ``assert``: asserts vanish under ``python -O``.
    if len(x.shape) != 4 or len(output.shape) != 4:
        raise ValueError('x and output must be 4-D, got %d-D and %d-D'
                         % (len(x.shape), len(output.shape)))
    if len(weights.shape) != 4:
        raise ValueError('weights must be 4-D, got %d-D' % len(weights.shape))

    batch_size, input_channel, in_h, in_w = x.shape
    output_batch_size, output_channel, out_h, out_w = output.shape
    num_filters, filter_channel, k_h, k_w = weights.shape

    if batch_size != output_batch_size:
        raise ValueError('batch size mismatch: x has %d, output has %d'
                         % (batch_size, output_batch_size))
    if output_channel != num_filters:
        raise ValueError('output has %d channels but there are %d filters'
                         % (output_channel, num_filters))
    if filter_channel != input_channel:
        raise ValueError('filters have %d channels but x has %d'
                         % (filter_channel, input_channel))
    # scipy's 'valid' mode would silently swap operands if the kernel were
    # the larger one, which is meaningless for a conv layer.
    if k_h > in_h or k_w > in_w:
        raise ValueError('filter %dx%d is larger than input %dx%d'
                         % (k_h, k_w, in_h, in_w))
    # Without this check NumPy assignment would silently broadcast a
    # too-small result (e.g. 1x1) across the whole output slice.
    expected = (in_h - k_h + 1, in_w - k_w + 1)
    if (out_h, out_w) != expected:
        raise ValueError('output spatial shape %r, expected %r'
                         % ((out_h, out_w), expected))

    # Reshape each bias once (not once per image). np.reshape also accepts
    # plain Python floats, and rejects entries that are not exactly one value.
    biases = [np.reshape(bias[f], (1, 1, 1)) for f in range(num_filters)]

    for img_idx in range(batch_size):
        image = x[img_idx]
        for f in range(num_filters):
            # Shape (1, out_h, out_w): the channel axis collapses to 1.
            response = correlate(image, weights[f], mode='valid') + biases[f]
            output[img_idx, f] = response[0]
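# Stray web-page residue from the page this code was scraped from
# (not part of the program): "评论列表" = "comment list", "文章目录" = "table of contents".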