layers.py 文件源码

python
阅读 21 收藏 0 点赞 0 评论 0

项目:rnn-theano 作者: wangxggc 项目源码 文件源码
def get_output(self, input):
        """Apply convolution, per-filter bias and pooling; return one value per filter.

        :param input: A 4-D tensor with shape (batch_size, channels, sen_len,
                      embedding_size); usually embedding_size == filter_width.
        :return: A 2-D tensor with shape (batch_size, filter_size). The final
                 reshape only succeeds when pooling reduces every feature map
                 to a single value, so ``self.pooling_shape`` presumably equals
                 the convolution output's spatial size
                 (sen_len - filter_height + 1, embedding_size - filter_width + 1)
                 -- confirm against the layer's constructor.

        Pooling mode: ``self.pooling_mode == "average"`` selects
        ``"average_inc_pad"`` pooling; any other value falls back to max pooling.
        """
        # "valid" convolution: output has shape (batch_size, filter_size,
        # sen_len - filter_height + 1, embedding_size - filter_width + 1),
        # i.e. (..., 1) in the usual embedding_size == filter_width case.
        output = T.nnet.conv2d(input=input,
                               filters=self.params[self.id + "conv_w"],
                               input_shape=self.input_shape,
                               filter_shape=self.filter_shape,
                               border_mode="valid")
        # One bias per filter, broadcast over the batch and both spatial axes.
        output += self.params[self.id + "conv_b"].dimshuffle("x", 0, "x", "x")

        # Only "average" is special-cased; every other value (including "max"
        # and unrecognised strings) uses max pooling, matching prior behavior.
        mode = "average_inc_pad" if self.pooling_mode == "average" else "max"

        # Non-overlapping pooling: the stride equals the pooling window.
        # NOTE(review): ds/st/padding are the older Theano keyword names (later
        # renamed ws/stride/pad); kept as-is for the Theano version this
        # project targets -- confirm before renaming.
        output = pool.pool_2d(input=output,
                              ignore_border=True,
                              ds=self.pooling_shape,
                              st=self.pooling_shape,
                              padding=(0, 0),
                              mode=mode)

        return output.flatten().reshape([self.batch_size, self.filter_size])
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号