layer.py 文件源码

python
阅读 34 收藏 0 点赞 0 评论 0

项目: hart 作者: akosiorek 项目源码 文件源码
def _forward(self, vs):
    """Convolve each example in the batch with its own (dynamic) filter bank.

    Reads ``self.inpt`` and ``self.filters`` and stores the result in
    ``self.output``; nothing is returned.

    Two code paths, selected by ``self.local``:

    * local: ``self.inpt`` is expanded into image patches with
      ``tf.extract_image_patches`` and ``self.filters`` is split into
      ``self.n_filters`` equal chunks along axis 3. Each output channel is
      the element-wise product of the patches with one chunk, summed over
      axis 3, so the filter may differ at every output position.
    * non-local: ``self.filters`` is reshaped into one
      ``[ksize[0], ksize[1], n_cin, n_filters]`` kernel per example and
      ``tf.nn.conv2d`` is mapped over the batch, pairing each example with
      its own kernel.

    Args:
        vs: Not read by this method; kept so the signature stays unchanged
            for callers (presumably a variable scope -- TODO confirm).
    """
    if self.local:  # expand input patches and split by filters
        # NOTE(review): axis 3 is treated as the channel/patch-depth axis
        # throughout, i.e. this assumes NHWC tensors -- confirm.
        patches = tf.extract_image_patches(self.inpt,
                                           pad_shape(self.ksize),
                                           self.strides,
                                           [1, 1, 1, 1],
                                           padding=self.padding)

        filter_chunks = tf.split(axis=3, num_or_size_splits=self.n_filters,
                                 value=self.filters)
        # One output channel per chunk: multiply with the patches and sum over
        # axis 3. keep_dims leaves a size-1 axis 3 so the per-filter results
        # can be concatenated back into a channel axis.
        channels = [tf.reduce_sum(tf.multiply(chunk, patches), 3, keep_dims=True)
                    for chunk in filter_chunks]
        self.output = tf.concat(axis=3, values=channels)
    else:  # split by images in batch and map to regular conv2d function
        # map_fn iterates over axis 0; inserting a size-1 axis at position 1
        # makes every mapped element a batch of one image, which is the 4-D
        # input tf.nn.conv2d expects.
        inpt = tf.expand_dims(self.inpt, 1)

        # One kernel per example; the leading -1 lets TF infer how many.
        filt_shape = [-1, self.ksize[0], self.ksize[1], self.n_cin, self.n_filters]
        filt = tf.reshape(self.filters, filt_shape)

        def conv_one(image_and_kernel):
            # Plain 2-D convolution of one (1-image) batch with its own kernel.
            image, kernel = image_and_kernel
            return tf.nn.conv2d(image, kernel, self.strides, self.padding)

        # conv2d's output dtype follows its input, so take map_fn's dtype from
        # the input rather than hard-coding float32 (base_dtype strips TF1
        # `*_ref` variable dtypes). infer_shape=False disables map_fn's
        # consistent-shape checks; the static shape is restored below.
        result = tf.map_fn(conv_one, (inpt, filt),
                           dtype=self.inpt.dtype.base_dtype, infer_shape=False)
        # Remove the singleton axis added before mapping.
        result = tf.squeeze(result, [1])
        # NOTE(review): this declares that the output keeps the input's batch
        # and spatial dims, which only holds for shape-preserving convolutions
        # (e.g. 'SAME' padding with unit spatial strides) -- confirm for other
        # stride/padding settings.
        result.set_shape(self.inpt.get_shape()[:-1].concatenate([self.n_filters]))
        self.output = result
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号