basic.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:aspect_adversarial 作者: yuanzh 项目源码 文件源码
def __init__(self, n_in, n_out, activation=tanh,
            order=1, clip_gradients=False, BN=False):
        """Set up the layer's configuration, filter weights and bias / BN sub-layer.

        Args:
            n_in: Size of the "in" axis of both the input and the filter
                (presumably the number of input channels -- confirm with
                the forward pass).
            n_out: Size of the "out" axis of the filter; also the length of
                the bias vector.
            activation: Stored as ``self.activation``; not called here.
            order: Size of the filter's last ("col") axis.
            clip_gradients: Stored as ``self.clip_gradients``; not used here.
            BN: When truthy, a ``BatchNormalization`` sub-layer is created
                instead of a bias. In that case ``self.bias`` is never
                assigned by this constructor.
        """
        self.n_in, self.n_out = n_in, n_out
        self.activation, self.order = activation, order
        self.clip_gradients = clip_gradients

        # Axis layout of the input: (batch, in, row, col).
        self.input_shape = (None, n_in, 1, None)
        # Axis layout of the filter: (out, in, row, col).
        self.filter_shape = (n_out, n_in, 1, order)

        # W is initialised before the bias so the random_init call order
        # is W first, bias second.
        initial_filter = random_init(self.filter_shape)
        self.W = create_shared(initial_filter, name="W")
        if not BN:
            initial_bias = random_init((n_out,))
            self.bias = create_shared(initial_bias, name="bias")

        self.BNLayer, self.BN = None, BN
        if BN:
            # The BN layer gets the input shape with its "in" axis replaced
            # by the filter's "out" size:
            # (mini_batch_size, # of channel, # row, # column).
            batch_axis = self.input_shape[:1]
            spatial_axes = self.input_shape[2:]
            bn_shape = batch_axis + (self.filter_shape[0],) + spatial_axes
            self.BNLayer = BatchNormalization(bn_shape, mode=1)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号