def _forward(self, vs):
    """Compute ``self.output`` by applying ``self.filters`` to ``self.inpt``.

    Two code paths, selected by ``self.local``:

    * Local (``self.local`` truthy): ``self.inpt`` is expanded into image
      patches with ``tf.extract_image_patches``. ``self.filters`` is split
      into ``self.n_filters`` equal slices along axis 3, and each slice is
      multiplied element-wise with the patch tensor and summed over axis 3.
      This gives one output channel per filter.
    * Otherwise: ``self.filters`` is reshaped to
      ``[-1, ksize[0], ksize[1], n_cin, n_filters]``. ``tf.map_fn`` then runs
      ``tf.nn.conv2d`` once per batch element, pairing element ``i`` of
      ``self.inpt`` with filter bank ``i``.

    Args:
        vs: Not used by this method (kept for interface compatibility).

    Returns:
        None. The result is stored in ``self.output``.
    """
    if self.local:  # expand input patches and split by filters
        # ``pad_shape`` is defined elsewhere. Presumably it turns
        # (kh, kw) into the 4-element [1, kh, kw, 1] form that
        # extract_image_patches expects -- confirm.
        # Rates of [1, 1, 1, 1] mean no dilation.
        patches = tf.extract_image_patches(self.inpt,
                                           pad_shape(self.ksize),
                                           self.strides,
                                           [1, 1, 1, 1],
                                           padding=self.padding)
        # Each filter slice must broadcast against ``patches``. The exact
        # layout of self.filters is assumed here -- TODO confirm.
        channels = [
            tf.reduce_sum(tf.multiply(filt, patches), 3, keep_dims=True)
            for filt in tf.split(axis=3,
                                 num_or_size_splits=self.n_filters,
                                 value=self.filters)
        ]
        self.output = tf.concat(axis=3, values=channels)
    else:  # split by images in batch and map to regular conv2d function
        # Insert a singleton axis so that each map_fn element is a rank-4
        # [1, H, W, C] batch, as tf.nn.conv2d requires.
        inpt = tf.expand_dims(self.inpt, 1)
        filt_shape = [-1, self.ksize[0], self.ksize[1], self.n_cin, self.n_filters]
        filt = tf.reshape(self.filters, filt_shape)

        def conv_one(elem):
            """Convolve one [1, H, W, C] image with its own filter bank."""
            image, kernel = elem
            return tf.nn.conv2d(image, kernel, self.strides, self.padding)

        def out_size(size, k, stride):
            """Return conv2d's output length along one spatial axis, or None."""
            if size is None or stride is None:
                return None
            if self.padding == 'SAME':
                return -(-size // stride)  # ceil(size / stride)
            if self.padding == 'VALID':
                return -(-(size - k + 1) // stride)  # ceil((size-k+1)/stride)
            return None  # unrecognised padding spec: leave dimension unknown

        # map_fn slices both tensors along axis 0, so the number of filter
        # banks must equal the batch size. conv2d preserves dtype, so the
        # declared output dtype follows the input rather than being pinned
        # to float32 (base_dtype strips a possible *_ref variable dtype).
        result = tf.map_fn(conv_one, (inpt, filt),
                           dtype=self.inpt.dtype.base_dtype,
                           infer_shape=False)
        result = tf.squeeze(result, [1])

        # infer_shape=False loses the static shape, so restore it here.
        # The spatial size depends on strides and padding. Copying the
        # input's H and W is only correct for stride 1 with SAME padding.
        in_shape = self.inpt.get_shape()
        dims = in_shape.as_list() if in_shape.ndims is not None else [None] * 4
        try:
            stride_h, stride_w = int(self.strides[1]), int(self.strides[2])
        except (TypeError, IndexError):
            # Not a 4-element NHWC stride list: leave spatial sizes unknown.
            stride_h = stride_w = None
        result.set_shape([dims[0],
                          out_size(dims[1], self.ksize[0], stride_h),
                          out_size(dims[2], self.ksize[1], stride_w),
                          self.n_filters])
        self.output = result
评论列表
文章目录