gated_cnn.py source code

python

Project: eva  Author: israelg99
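def __call__(self, model1, model2=None):
    # First layer: one input feeds both stacks and a wide 7x7 filter is used.
    # Later layers: model1 is the vertical stack, model2 the horizontal stack.
    if model2 is None:
        h_model = model1
        filter_size = (7, 7)
    else:
        h_model = model2
        filter_size = (3, 3)

    # Vertical stack: padded convolution followed by a gated block.
    v_model = PaddedConvolution2D(self.filters, filter_size, 'vertical')(model1)
    # Taken before gating so the horizontal stack can see the vertical stack.
    feed_vertical = FeedVertical(self.filters)(v_model)
    v_model = GatedBlock(self.filters, h=self.h)(v_model)

    # Horizontal stack: padded convolution, gating with the vertical feed, then a 1x1 convolution.
    h_model_new = PaddedConvolution2D(self.filters, filter_size, 'horizontal', 'A')(h_model)
    h_model_new = GatedBlock(self.filters, v=feed_vertical, h=self.h, crop_right=True)(h_model_new)
    h_model_new = Convolution2D(self.filters, 1, 1, border_mode='valid')(h_model_new)

    # Every layer after the first sums the new horizontal output with its input (residual connection).
    if model2 is not None:
        h_model_new = Merge(mode='sum')([h_model_new, h_model])
    return (v_model, h_model_new)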
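
The body of GatedBlock is not shown in this file. Assuming it applies the usual gated activation from gated PixelCNN-style models (split the feature channels in half, then multiply tanh of one half by the sigmoid of the other), the core operation could be sketched as below. The function name `gated_activation` and the flat list-of-channels input are hypothetical and chosen only for illustration; this is not the eva implementation.

```python
import math


def gated_activation(channels):
    """Gate 2*k channel values: tanh(first half) * sigmoid(second half).

    Hypothetical helper for illustration; not part of the eva project.
    """
    if len(channels) % 2 != 0:
        raise ValueError("expected an even number of channels")
    half = len(channels) // 2
    a, b = channels[:half], channels[half:]
    # Element-wise product of the tanh branch and the sigmoid gate.
    return [math.tanh(x) * (1.0 / (1.0 + math.exp(-y))) for x, y in zip(a, b)]


# Four input channels produce two gated output channels.
gated = gated_activation([0.0, 1.0, 5.0, 0.0])
print(gated)
```

Under this assumption the block halves the channel count, which would be consistent with the 1x1 convolution that follows the horizontal gate in the code above.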